"""
Loads the saved model artifacts and predicts the origin language of one or
more katakana loanwords provided by the user.
The model, vectorizer, and label encoder must already exist in models/.
Run scripts/train.py first if they are not there yet.
Usage (interactive prompt):
python scripts/predict.py
Usage (pass words directly as arguments):
python scripts/predict.py テレビ コーヒー アルバイト
Output example:
テレビ -> English
コーヒー -> Dutch
アルバイト -> German
"""
import sys
import os
import re
# Allow imports from the project root regardless of where the script is called from
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import joblib
# Directory that holds the saved artifacts. The path is relative, so it is
# resolved against the current working directory at run time.
MODEL_DIR = "models"
# Matches strings made entirely of katakana characters (same rule as loader.py).
# The character class covers code points U+30A0 through U+30FF. Note that `$`
# also matches just before a trailing newline, so callers that need a strict
# whole-string test should use fullmatch() rather than match().
KATAKANA_PATTERN = re.compile(r"^[\u30A0-\u30FF]+$")
def load_artifacts():
    """
    Read the model, vectorizer and label encoder back from MODEL_DIR.

    All three files are checked up front. If any is absent, the missing
    paths are listed, the user is told to run train.py, and the process
    exits with status 1 instead of surfacing a raw FileNotFoundError.

    Returns:
        A (model, vectorizer, label_encoder) tuple.
    """
    artifact_paths = [
        os.path.join(MODEL_DIR, filename)
        for filename in ("model.joblib", "vectorizer.joblib", "encoder.joblib")
    ]

    absent = list(filter(lambda p: not os.path.exists(p), artifact_paths))
    if absent:
        print("\n[predict] Error: the following model files were not found:")
        for absent_path in absent:
            print(f" {absent_path}")
        print("\n Run scripts/train.py first to generate them.\n")
        sys.exit(1)

    # joblib.load unpickles the files, so they must come from a trusted source.
    # Loading happens in the same order as the paths above.
    model, vectorizer, label_encoder = map(joblib.load, artifact_paths)
    return model, vectorizer, label_encoder
def predict(
    words: list[str], model, vectorizer, label_encoder
) -> list[tuple[str, str]]:
    """
    Predict the origin language for each katakana word.

    Words that are not pure katakana are flagged as invalid and skipped
    rather than silently passed to the model with meaningless features.

    Args:
        words         : Strings to classify, in the order they should be reported.
        model         : Fitted classifier exposing predict() (a LinearSVC
                        according to the project's notes).
        vectorizer    : Fitted vectorizer exposing transform() (a
                        TfidfVectorizer according to the project's notes).
        label_encoder : Fitted encoder exposing inverse_transform().

    Returns:
        List of (word, prediction) tuples, one per input word and in input
        order. Invalid words get '(invalid — not pure katakana)' as their
        prediction so the output table stays aligned.
    """
    invalid_marker = "(invalid — not pure katakana)"

    # fullmatch() rather than match(): the pattern ends in `$`, which also
    # matches just before a trailing newline, so match() would accept a word
    # such as "カメラ\n" as pure katakana and feed the newline to the model.
    valid_words = [w for w in words if KATAKANA_PATTERN.fullmatch(w)]

    word_to_pred = {}
    if valid_words:
        # Vectorize all valid words in one batch for efficiency
        X = vectorizer.transform(valid_words)
        predictions = label_encoder.inverse_transform(model.predict(X))
        word_to_pred = dict(zip(valid_words, predictions))

    # Every valid word has an entry in word_to_pred; anything absent is invalid.
    return [(word, word_to_pred.get(word, invalid_marker)) for word in words]
def _display_width(text: str) -> int:
    """Return how many terminal columns *text* occupies (wide chars count as 2)."""
    import unicodedata

    return sum(
        2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1 for ch in text
    )


def print_results(results: list[tuple[str, str]]) -> None:
    """
    Print predictions in a clean aligned two-column table.

    Katakana are full-width characters that take two terminal columns each,
    so padding is computed from display width rather than len(); otherwise
    rows with words of different lengths would not line up. The first column
    is also at least as wide as its "Katakana" header.

    Args:
        results : (word, prediction) tuples to display. Nothing is printed
                  when the list is empty.
    """
    if not results:
        return

    header = "Katakana"
    labels = [header, *(word for word, _ in results)]
    column_width = max(_display_width(label) for label in labels) + 2

    def pad(text: str) -> str:
        # str.format padding counts code points, so pad by hand using columns.
        return text + " " * (column_width - _display_width(text))

    print()
    print(f" {pad(header)} Predicted Origin Language")
    print(f' {"-" * column_width} {"-" * 28}')
    for word, prediction in results:
        print(f" {pad(word)} {prediction}")
    print()
def interactive_mode(model, vectorizer, label_encoder):
    """
    Prompt repeatedly for katakana words and print a prediction table each time.

    A line may contain several whitespace-separated words. Blank lines are
    ignored. Entering 'quit', 'exit' or 'q' (any letter case), pressing
    Ctrl+C, or reaching end of input ends the loop.

    Args:
        model, vectorizer, label_encoder : Fitted artifacts, passed through
            unchanged to predict().
    """
    # NOTE(review): the dash in the banner and the farewell text below look
    # mis-encoded (mojibake). They are kept verbatim here; confirm the intended
    # characters against the original file.
    exit_commands = {"quit", "exit", "q"}

    print("\n[predict] Gairaigo Origin Classifier β interactive mode")
    print(' Enter katakana words (space-separated). Type "quit" to exit.\n')

    while True:
        try:
            line = input(" >> ").strip()
        except (KeyboardInterrupt, EOFError):
            # Leading newline moves off the interrupted prompt line.
            print("\n γγγγοΌ\n")
            return

        if line.lower() in exit_commands:
            print(" γγγγοΌ\n")
            return

        if line:
            print_results(predict(line.split(), model, vectorizer, label_encoder))
def main():
    """
    Entry point: load the artifacts, then classify CLI words or start the prompt.

    When at least one command-line argument is given, every argument is
    treated as a word, the predictions are printed once, and the function
    returns. With no arguments the interactive prompt is started instead.
    """
    artifacts = load_artifacts()
    _, _, label_encoder = artifacts
    print(f"\n[predict] Model loaded. Known classes: {list(label_encoder.classes_)}")

    cli_words = sys.argv[1:]
    if not cli_words:
        # No arguments, launch interactive prompt
        interactive_mode(*artifacts)
        return

    # Words were passed as CLI arguments: classify them and exit
    print_results(predict(cli_words, *artifacts))


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|