File size: 5,153 Bytes
caf26c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Loads the saved model artifacts and predicts the origin language of one or
more katakana loanwords provided by the user.

The model, vectorizer, and label encoder must already exist in models/.
Run scripts/train.py first if they are not there yet.

Usage (interactive prompt):
  python scripts/predict.py

Usage (pass words directly as arguments):
  python scripts/predict.py テレビ γ‚³γƒΌγƒ’γƒΌ γ‚’γƒ«γƒγ‚€γƒˆ

Output example:
  テレビ       -> English
  γ‚³γƒΌγƒ’γƒΌ     -> Dutch
  γ‚’γƒ«γƒγ‚€γƒˆ   -> German
"""

import sys
import os
import re

# Allow imports from the project root regardless of where the script is called from
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import joblib

# Directory (relative to the working directory) where train.py saves artifacts.
MODEL_DIR = "models"

# Matches strings made entirely of katakana characters (same rule as loader.py)
KATAKANA_PATTERN = re.compile(r"^[\u30A0-\u30FF]+$")


def load_artifacts():
    """
    Load the trained model, vectorizer, and label encoder from models/.

    If any artifact file is absent, print an actionable message pointing
    the user at scripts/train.py and exit with status 1, rather than
    letting a raw FileNotFoundError surface.

    Returns:
        (model, vectorizer, label_encoder) tuple of fitted objects.
    """
    # Insertion order matters: it fixes both the missing-file report order
    # and the order of the returned tuple.
    artifact_paths = {
        name: os.path.join(MODEL_DIR, f"{name}.joblib")
        for name in ("model", "vectorizer", "encoder")
    }

    missing = [p for p in artifact_paths.values() if not os.path.exists(p)]
    if missing:
        print("\n[predict] Error: the following model files were not found:")
        for path in missing:
            print(f"           {path}")
        print("\n          Run scripts/train.py first to generate them.\n")
        sys.exit(1)

    return tuple(joblib.load(p) for p in artifact_paths.values())


def predict(
    words: list[str], model, vectorizer, label_encoder
) -> list[tuple[str, str]]:
    """
    Classify the origin language of each katakana word.

    Inputs that are not pure katakana are never sent to the model; they
    get a placeholder prediction instead, so the caller's output table
    keeps exactly one row per input word, in input order.

    Args:
        words         : Katakana strings to classify.
        model         : Fitted LinearSVC.
        vectorizer    : Fitted TfidfVectorizer.
        label_encoder : Fitted LabelEncoder.

    Returns:
        One (word, prediction) tuple per input word. Invalid words carry
        an '(invalid ...)' marker as their prediction.
    """
    # Evaluate the katakana check once per word and remember the verdict.
    is_katakana = {w: bool(KATAKANA_PATTERN.match(w)) for w in words}
    valid = [w for w in words if is_katakana[w]]

    predicted: dict[str, str] = {}
    if valid:
        # Single batched transform/predict call instead of per-word calls.
        features = vectorizer.transform(valid)
        labels = label_encoder.inverse_transform(model.predict(features))
        predicted = dict(zip(valid, labels))

    return [
        (w, predicted[w] if is_katakana[w] else "(invalid — not pure katakana)")
        for w in words
    ]


def print_results(results: list[tuple[str, str]]):
    """Render (word, prediction) pairs as an aligned two-column table."""
    if not results:
        return

    # Pad the word column to the longest word plus two spaces of breathing room.
    width = max(len(word) for word, _ in results) + 2

    print()
    print(f'  {"Katakana":<{width}}  Predicted Origin Language')
    print(f'  {"-" * width}  {"-" * 28}')
    for word, prediction in results:
        print(f"  {word:<{width}}  {prediction}")
    print()


def interactive_mode(model, vectorizer, label_encoder):
    """
    Prompt the user for katakana words in a loop and print predictions.

    Several space-separated words may be entered per line. The loop ends
    on 'quit'/'exit'/'q', Ctrl+C, or EOF.
    """
    print("\n[predict] Gairaigo Origin Classifier — interactive mode")
    print('          Enter katakana words (space-separated). Type "quit" to exit.\n')

    while True:
        try:
            line = input("  >> ").strip()
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C / Ctrl+D: move off the prompt line, say goodbye, stop.
            print("\n  じゃあね！\n")
            break

        if line.lower() in ("quit", "exit", "q"):
            print("  じゃあね！\n")
            break

        # Empty line: just re-prompt.
        if not line:
            continue

        print_results(predict(line.split(), model, vectorizer, label_encoder))


def main():
    """Entry point: load artifacts, then classify CLI args or go interactive."""
    model, vectorizer, label_encoder = load_artifacts()
    print(f"\n[predict] Model loaded. Known classes: {list(label_encoder.classes_)}")

    cli_words = sys.argv[1:]
    if cli_words:
        # Batch mode: classify the argument words, print, and return.
        print_results(predict(cli_words, model, vectorizer, label_encoder))
    else:
        # No arguments were given, so fall back to the interactive prompt.
        interactive_mode(model, vectorizer, label_encoder)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()