File size: 5,484 Bytes
8ff1b66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
FastText model training.

FastText is preferred over Word2Vec because:
- Better handling of typos and misspellings (common in reviews)
- Can generate vectors for out-of-vocabulary words
- Uses character n-grams internally
"""

import logging
from pathlib import Path

from gensim.models import FastText

from .config import MODELS_DIR, SETTINGS

logger = logging.getLogger(__name__)


class FastTextTrainer:
    """
    Trains FastText word embeddings on a review corpus.

    FastText's character n-grams tolerate typos/misspellings (common in
    reviews) and allow vectors for out-of-vocabulary words.
    """

    # Default model filename under MODELS_DIR used by save() and load().
    DEFAULT_FILENAME = "fasttext.model"

    def __init__(
        self,
        vector_size: int | None = None,
        window: int | None = None,
        min_count: int | None = None,
        epochs: int | None = None,
        workers: int | None = None,
    ):
        """
        Initialize trainer with hyperparameters.

        Any argument left as None falls back to the corresponding
        SETTINGS entry.

        Args:
            vector_size: Dimensionality of word vectors
            window: Context window size
            min_count: Minimum word frequency
            epochs: Number of training iterations
            workers: Number of worker threads
        """
        # Use `is not None` (not `or`) so an explicitly passed falsy
        # value such as 0 is not silently replaced by the default.
        self.vector_size = self._resolve(vector_size, "fasttext_vector_size")
        self.window = self._resolve(window, "fasttext_window")
        self.min_count = self._resolve(min_count, "fasttext_min_count")
        self.epochs = self._resolve(epochs, "fasttext_epochs")
        self.workers = self._resolve(workers, "fasttext_workers")

        # Populated by train() or load().
        self.model: FastText | None = None

    @staticmethod
    def _resolve(value: int | None, key: str) -> int:
        """Return *value* if explicitly provided, else SETTINGS[key]."""
        return value if value is not None else SETTINGS[key]

    @staticmethod
    def _normalize(word: str) -> str:
        """Lowercase *word* and join multi-word phrases with underscores."""
        return word.lower().replace(" ", "_")

    @classmethod
    def _default_path(cls, path: Path | str | None) -> Path:
        """Coerce *path* to a Path, defaulting to MODELS_DIR/fasttext.model."""
        return Path(path) if path else MODELS_DIR / cls.DEFAULT_FILENAME

    def train(self, sentences: list[list[str]]) -> FastText:
        """
        Train FastText model on tokenized sentences.

        Args:
            sentences: List of tokenized documents (output from preprocessor)

        Returns:
            Trained FastText model

        Raises:
            ValueError: If `sentences` is empty.
        """
        if not sentences:
            # gensim would otherwise fail later with an opaque
            # "you must first build vocabulary" error; fail fast instead.
            raise ValueError("Cannot train FastText on an empty corpus.")

        # Lazy %-style args so formatting is skipped if INFO is disabled.
        logger.info(
            "Training FastText model: vector_size=%s, window=%s, "
            "min_count=%s, epochs=%s",
            self.vector_size,
            self.window,
            self.min_count,
            self.epochs,
        )
        logger.info("Training on %d documents", len(sentences))

        self.model = FastText(
            sentences=sentences,
            vector_size=self.vector_size,
            window=self.window,
            min_count=self.min_count,
            epochs=self.epochs,
            workers=self.workers,
            sg=1,  # Skip-gram (better for semantic similarity)
            min_n=3,  # Minimum character n-gram length
            max_n=6,  # Maximum character n-gram length
        )

        logger.info(
            "Training complete. Vocabulary size: %d", len(self.model.wv)
        )
        return self.model

    def save(self, path: Path | str | None = None) -> Path:
        """
        Save trained model.

        Args:
            path: Save path (default: models/fasttext.model)

        Returns:
            Path where model was saved

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model to save. Train first.")

        path = self._default_path(path)
        # Ensure the target directory exists so gensim's save() cannot
        # fail on a missing parent folder.
        path.parent.mkdir(parents=True, exist_ok=True)
        self.model.save(str(path))
        logger.info("Saved model to %s", path)
        return path

    def load(self, path: Path | str | None = None) -> FastText:
        """
        Load model from file.

        Args:
            path: Model path (default: models/fasttext.model)

        Returns:
            Loaded FastText model

        Raises:
            FileNotFoundError: If no model file exists at `path`.
        """
        path = self._default_path(path)

        if not path.exists():
            raise FileNotFoundError(f"Model not found at {path}")

        self.model = FastText.load(str(path))
        logger.info(
            "Loaded model from %s. Vocabulary size: %d", path, len(self.model.wv)
        )
        return self.model

    def get_similar(
        self,
        word: str,
        topn: int = 10,
    ) -> list[tuple[str, float]]:
        """
        Get most similar words to a given word.

        Args:
            word: Query word
            topn: Number of results

        Returns:
            List of (word, similarity) tuples; empty if `word` is
            unknown to the model.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")

        try:
            return self.model.wv.most_similar(self._normalize(word), topn=topn)
        except KeyError:
            logger.warning("Word '%s' not in vocabulary", word)
            return []

    def get_similarity(self, word1: str, word2: str) -> float:
        """
        Get similarity between two words.

        Args:
            word1: First word
            word2: Second word

        Returns:
            Cosine similarity (-1 to 1); 0.0 if either word is unknown.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")

        try:
            return float(
                self.model.wv.similarity(
                    self._normalize(word1), self._normalize(word2)
                )
            )
        except KeyError as e:
            logger.warning("Word not in vocabulary: %s", e)
            return 0.0

    def word_in_vocab(self, word: str) -> bool:
        """Check if word is in vocabulary (False when no model is loaded)."""
        if self.model is None:
            return False
        return self._normalize(word) in self.model.wv

    def get_vocab_words(self) -> list[str]:
        """Get all words in vocabulary (empty when no model is loaded)."""
        if self.model is None:
            return []
        return list(self.model.wv.key_to_index.keys())