krislette commited on
Commit
bf511a8
·
1 Parent(s): 98a8573

Created server files

Browse files
Files changed (2) hide show
  1. api/main.py +182 -0
  2. api/requirements.txt +6 -0
api/main.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI backend that loads the saved model artifacts
3
+ and serves predictions for the Discover Mode of the Gairaigo Map.
4
+
5
+ Endpoints:
6
+ GET /health — liveness check
7
+ GET /languages — returns the 3 classifiable languages
8
+ POST /predict — classifies a katakana word
9
+
10
+ Usage:
11
+ uvicorn main:app --reload --port 8000
12
+ """
13
+
14
+ import re
15
+ import numpy as np
16
+ import joblib
17
+ from pathlib import Path
18
+ from contextlib import asynccontextmanager
19
+ from fastapi import FastAPI, HTTPException
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ from pydantic import BaseModel, field_validator
22
+
23
+
24
+ # Paths
25
+ BASE_DIR = Path(__file__).parent
26
+ MODELS_DIR = BASE_DIR.parent / "models"
27
+
28
+ MODEL_PATH = MODELS_DIR / "model.joblib"
29
+ VECTORIZER_PATH = MODELS_DIR / "vectorizer.joblib"
30
+ ENCODER_PATH = MODELS_DIR / "encoder.joblib"
31
+
32
+ # Katakana validation
33
+ KATAKANA_RE = re.compile(r"^[\u30A0-\u30FF\u30FC\u30FB\u30FE\u30FD]+$")
34
+
35
+
36
+ def is_katakana(text: str) -> bool:
37
+ return bool(KATAKANA_RE.match(text.strip()))
38
+
39
+
40
+ # Language metadata for the three classifiable languages
41
+ # (mirrors what the frontend needs to highlight the map)
42
+ LANGUAGE_META = {
43
+ "English": {"iso2": "GB", "country": "United Kingdom", "color": "#4a90d9"},
44
+ "French": {"iso2": "FR", "country": "France", "color": "#e85d5d"},
45
+ "German": {"iso2": "DE", "country": "Germany", "color": "#f0a500"},
46
+ }
47
+
48
+
49
+ # Lifespan, load model artifacts once on startup
50
+ artifacts: dict = {}
51
+
52
+
53
+ @asynccontextmanager
54
+ async def lifespan(app: FastAPI):
55
+ for path in (MODEL_PATH, VECTORIZER_PATH, ENCODER_PATH):
56
+ if not path.exists():
57
+ raise RuntimeError(
58
+ f"Model artifact not found: {path}\n"
59
+ "Run `python -m scripts.train` from your kataklassifer project first."
60
+ )
61
+ artifacts["model"] = joblib.load(MODEL_PATH)
62
+ artifacts["vectorizer"] = joblib.load(VECTORIZER_PATH)
63
+ artifacts["encoder"] = joblib.load(ENCODER_PATH)
64
+ print("✓ Model artifacts loaded")
65
+ yield
66
+ artifacts.clear()
67
+
68
+
69
+ # App
70
+ app = FastAPI(
71
+ title="Gairaigo Map API",
72
+ description="Classifies Japanese katakana loanwords into English, French, or German.",
73
+ version="1.0.0",
74
+ lifespan=lifespan,
75
+ )
76
+
77
+ app.add_middleware(
78
+ CORSMiddleware,
79
+ allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],
80
+ allow_methods=["GET", "POST"],
81
+ allow_headers=["*"],
82
+ )
83
+
84
+
85
+ # Schemas
86
+ class PredictRequest(BaseModel):
87
+ word: str
88
+
89
+ @field_validator("word")
90
+ @classmethod
91
+ def must_be_katakana(cls, v: str) -> str:
92
+ v = v.strip()
93
+ if not v:
94
+ raise ValueError("Word must not be empty.")
95
+ if not is_katakana(v):
96
+ raise ValueError(
97
+ "Input must be a katakana string (e.g. コーヒー). "
98
+ "Hiragana, kanji, or romaji are not supported."
99
+ )
100
+ return v
101
+
102
+
103
+ class LanguageResult(BaseModel):
104
+ language: str
105
+ country: str
106
+ iso2: str
107
+ confidence: float
108
+ color: str
109
+
110
+
111
+ class PredictResponse(BaseModel):
112
+ word: str
113
+ prediction: LanguageResult
114
+ all_scores: list[LanguageResult]
115
+
116
+
117
+ # Helpers
118
+ def softmax(scores: np.ndarray) -> np.ndarray:
119
+ """Convert raw SVM decision scores to a probability-like distribution."""
120
+ exp_scores = np.exp(scores - np.max(scores))
121
+ return exp_scores / exp_scores.sum()
122
+
123
+
124
+ def classify(word: str) -> PredictResponse:
125
+ model = artifacts["model"]
126
+ vectorizer = artifacts["vectorizer"]
127
+ encoder = artifacts["encoder"]
128
+
129
+ X = vectorizer.transform([word])
130
+
131
+ # decision_function returns shape (1, n_classes) for multi-class LinearSVC
132
+ decision_scores = model.decision_function(X)[0] # shape: (3,)
133
+ confidences = softmax(decision_scores) # normalized to sum=1
134
+ _ = int(np.argmax(confidences))
135
+ classes = encoder.classes_ # e.g. ["English", "French", "German"]
136
+
137
+ all_scores = [
138
+ LanguageResult(
139
+ language=classes[i],
140
+ country=LANGUAGE_META[classes[i]]["country"],
141
+ iso2=LANGUAGE_META[classes[i]]["iso2"],
142
+ confidence=round(float(confidences[i]), 4),
143
+ color=LANGUAGE_META[classes[i]]["color"],
144
+ )
145
+ for i in range(len(classes))
146
+ ]
147
+
148
+ # Sort descending by confidence for the frontend
149
+ all_scores.sort(key=lambda r: r.confidence, reverse=True)
150
+
151
+ return PredictResponse(
152
+ word=word,
153
+ prediction=all_scores[0],
154
+ all_scores=all_scores,
155
+ )
156
+
157
+
158
+ # Routes
159
+ @app.get("/health", tags=["Meta"])
160
+ def health():
161
+ return {"status": "ok", "model_loaded": bool(artifacts)}
162
+
163
+
164
+ @app.get("/languages", tags=["Meta"])
165
+ def get_languages():
166
+ """Returns metadata for the 3 classifiable donor languages."""
167
+ return {lang: meta for lang, meta in LANGUAGE_META.items()}
168
+
169
+
170
+ @app.post("/predict", response_model=PredictResponse, tags=["Classification"])
171
+ def predict(body: PredictRequest):
172
+ """
173
+ Classify a single katakana loanword.
174
+
175
+ - **word**: A katakana string, e.g. `コーヒー`, `アルバイト`, `テレビ`
176
+ - Returns the predicted donor language with a softmax confidence score,
177
+ plus all 3 languages ranked by confidence.
178
+ """
179
+ try:
180
+ return classify(body.word)
181
+ except Exception as e:
182
+ raise HTTPException(status_code=500, detail=str(e))
api/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi>=0.111.0
2
+ uvicorn[standard]>=0.29.0
3
+ pydantic>=2.7.0
4
+ joblib>=1.4.0
5
+ numpy>=1.26.0
6
+ scikit-learn>=1.4.0