Mirae Kang committed on
Commit ·
0f0ce9b
1
Parent(s): e317d56
fix: debug model selection, #22
Browse files- .gitattributes +1 -0
- Dockerfile +4 -0
- README.md +14 -1
- configs/model_catalog.yaml +3 -2
- configs/suggested_videos.yaml +4 -4
- frontend/src/api/client.ts +21 -4
- frontend/src/components/SuggestedRail.tsx +12 -1
- frontend/src/pages/SettingsPage.tsx +8 -1
- scripts/materialize_finetuned_weights.py +56 -0
- src/api/main.py +9 -2
- src/api/routes/models.py +37 -25
- src/api/schemas.py +4 -0
- src/service/model_service.py +51 -6
- tests/test_api.py +23 -0
.gitattributes
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 2 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
models/finetuned_hf/** filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.png filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
CHANGED
|
@@ -44,6 +44,10 @@ COPY configs/ configs/
|
|
| 44 |
COPY src/ src/
|
| 45 |
COPY models/final_model.joblib models/final_model.joblib
|
| 46 |
COPY models/finetuned_hf/ models/finetuned_hf/
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
COPY --from=frontend-build /app/frontend/dist frontend/dist
|
| 48 |
COPY .env.example .env.example
|
| 49 |
|
|
|
|
| 44 |
COPY src/ src/
|
| 45 |
COPY models/final_model.joblib models/final_model.joblib
|
| 46 |
COPY models/finetuned_hf/ models/finetuned_hf/
|
| 47 |
+
COPY scripts/materialize_finetuned_weights.py scripts/materialize_finetuned_weights.py
|
| 48 |
+
RUN if [ "$INSTALL_HF" = "1" ]; then \
|
| 49 |
+
uv run python scripts/materialize_finetuned_weights.py || true; \
|
| 50 |
+
fi
|
| 51 |
COPY --from=frontend-build /app/frontend/dist frontend/dist
|
| 52 |
COPY .env.example .env.example
|
| 53 |
|
README.md
CHANGED
|
@@ -46,6 +46,18 @@ uv run uvicorn src.api.main:app --reload --port 8000
|
|
| 46 |
|
| 47 |
Verify HF deps: `uv run python -c "import transformers; print('ok')"`.
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
| Resource | URL |
|
| 50 |
|----------|-----|
|
| 51 |
| Swagger | http://localhost:8000/docs |
|
|
@@ -60,7 +72,8 @@ Verify HF deps: `uv run python -c "import transformers; print('ok')"`.
|
|
| 60 |
| `GET` | `/videos/suggested` | Metadata for right-rail videos (from `configs/suggested_videos.yaml`) |
|
| 61 |
| `GET` | `/models` | Available models |
|
| 62 |
| `GET` | `/models/status` | Per-model availability (HF deps, local weights) |
|
| 63 |
-
| `
|
|
|
|
| 64 |
|
| 65 |
Set `YOUTUBE_API_KEY` in `.env` for real comments and suggested-video thumbnails.
|
| 66 |
|
|
|
|
| 46 |
|
| 47 |
Verify HF deps: `uv run python -c "import transformers; print('ok')"`.
|
| 48 |
|
| 49 |
+
**Fine-tuned (local HF)** needs real weight files in `models/finetuned_hf/` (not the 134-byte Git LFS pointer). **You do not need Git LFS** if you use:
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
uv sync --extra hf
|
| 53 |
+
uv run python scripts/materialize_finetuned_weights.py
|
| 54 |
+
ls -lh models/finetuned_hf/model.safetensors # should be ~250 MB+
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
Optional (if the team pushed weights with Git LFS): `brew install git-lfs`, then `git lfs install` and `git lfs pull`.
|
| 58 |
+
|
| 59 |
+
Without local weights, the API falls back to `martin-ha/toxic-comment-model` from Hugging Face Hub when you select this model.
|
| 60 |
+
|
| 61 |
| Resource | URL |
|
| 62 |
|----------|-----|
|
| 63 |
| Swagger | http://localhost:8000/docs |
|
|
|
|
| 72 |
| `GET` | `/videos/suggested` | Metadata for right-rail videos (from `configs/suggested_videos.yaml`) |
|
| 73 |
| `GET` | `/models` | Available models |
|
| 74 |
| `GET` | `/models/status` | Per-model availability (HF deps, local weights) |
|
| 75 |
+
| `POST` | `/models/select` | Switch active model `{"model_name": "..."}` (preferred) |
|
| 76 |
+
| `PUT` | `/model/{name}` | Legacy path-based model switch |
|
| 77 |
|
| 78 |
Set `YOUTUBE_API_KEY` in `.env` for real comments and suggested-video thumbnails.
|
| 79 |
|
configs/model_catalog.yaml
CHANGED
|
@@ -37,7 +37,8 @@
|
|
| 37 |
type: hf_local
|
| 38 |
icon: "✨"
|
| 39 |
model_path: models/finetuned_hf
|
| 40 |
-
|
|
|
|
| 41 |
speed: "Hardware dependent"
|
| 42 |
accuracy: "TBD"
|
| 43 |
-
requires: "uv sync --extra hf"
|
|
|
|
| 37 |
type: hf_local
|
| 38 |
icon: "✨"
|
| 39 |
model_path: models/finetuned_hf
|
| 40 |
+
hub_fallback: martin-ha/toxic-comment-model
|
| 41 |
+
description: "Local DistilBERT folder (models/finetuned_hf). Materialize weights if missing."
|
| 42 |
speed: "Hardware dependent"
|
| 43 |
accuracy: "TBD"
|
| 44 |
+
requires: "uv sync --extra hf; uv run python scripts/materialize_finetuned_weights.py"
|
configs/suggested_videos.yaml
CHANGED
|
@@ -9,7 +9,7 @@ videos:
|
|
| 9 |
note: 3Blue1Brown — embed-friendly educational
|
| 10 |
- id: dQw4w9WgXcQ
|
| 11 |
note: Rick Astley — usually embeddable
|
| 12 |
-
- id:
|
| 13 |
-
note:
|
| 14 |
-
- id:
|
| 15 |
-
note:
|
|
|
|
| 9 |
note: 3Blue1Brown — embed-friendly educational
|
| 10 |
- id: dQw4w9WgXcQ
|
| 11 |
note: Rick Astley — usually embeddable
|
| 12 |
+
- id: M7lc1UVf-VE
|
| 13 |
+
note: YouTube Developers — designed for embedding
|
| 14 |
+
- id: 8aGhZQkoFbQ
|
| 15 |
+
note: What is an API — tech talk, comments on
|
frontend/src/api/client.ts
CHANGED
|
@@ -54,13 +54,30 @@ export function getModels() {
|
|
| 54 |
return request<{ available: string[]; active: string }>("/models");
|
| 55 |
}
|
| 56 |
|
| 57 |
-
export function getModelsStatus() {
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
}
|
| 60 |
|
| 61 |
export function setModel(name: string) {
|
| 62 |
-
return request<{ message: string; model: string }>(
|
| 63 |
-
method: "
|
|
|
|
| 64 |
});
|
| 65 |
}
|
| 66 |
|
|
|
|
| 54 |
return request<{ available: string[]; active: string }>("/models");
|
| 55 |
}
|
| 56 |
|
| 57 |
+
/**
 * Fetch per-model availability from `/models/status`.
 *
 * When that request fails with an Error whose message contains "not found"
 * (case-insensitive), fall back to `/models` and report every listed model as
 * available with type "unknown". Any other failure is rethrown unchanged.
 */
export async function getModelsStatus() {
  try {
    return await request<{ models: ModelStatusEntry[]; active: string }>("/models/status");
  } catch (err) {
    const endpointMissing =
      err instanceof Error && err.message.toLowerCase().includes("not found");
    if (!endpointMissing) {
      throw err;
    }

    // Status endpoint unavailable: rebuild an equivalent payload from the plain model list.
    const { active, available } = await getModels();
    const models = available.map((name) => ({
      name,
      available: true,
      reason: null,
      type: "unknown",
    }));
    return { active, models };
  }
}
|
| 76 |
|
| 77 |
/** Ask the API to make `name` the active model via `POST /models/select`. */
export function setModel(name: string) {
  // The name travels in the JSON body rather than in the URL path.
  const payload = JSON.stringify({ model_name: name });
  return request<{ message: string; model: string }>("/models/select", {
    method: "POST",
    body: payload,
  });
}
|
| 83 |
|
frontend/src/components/SuggestedRail.tsx
CHANGED
|
@@ -19,7 +19,18 @@ export function SuggestedRail({ videos, activeId, loadingId, onSelect }: Props)
|
|
| 19 |
onClick={() => onSelect(v)}
|
| 20 |
disabled={loadingId === v.id}
|
| 21 |
>
|
| 22 |
-
<img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
<div className="suggested-info">
|
| 24 |
<p className="suggested-title">{v.title}</p>
|
| 25 |
<p className="suggested-channel">{v.channel_title}</p>
|
|
|
|
| 19 |
onClick={() => onSelect(v)}
|
| 20 |
disabled={loadingId === v.id}
|
| 21 |
>
|
| 22 |
+
<img
|
| 23 |
+
src={v.thumbnail_url}
|
| 24 |
+
alt=""
|
| 25 |
+
className="suggested-thumb"
|
| 26 |
+
onError={(e) => {
|
| 27 |
+
const img = e.currentTarget;
|
| 28 |
+
if (!img.dataset.fallback) {
|
| 29 |
+
img.dataset.fallback = "1";
|
| 30 |
+
img.src = `https://i.ytimg.com/vi/${v.id}/hqdefault.jpg`;
|
| 31 |
+
}
|
| 32 |
+
}}
|
| 33 |
+
/>
|
| 34 |
<div className="suggested-info">
|
| 35 |
<p className="suggested-title">{v.title}</p>
|
| 36 |
<p className="suggested-channel">{v.channel_title}</p>
|
frontend/src/pages/SettingsPage.tsx
CHANGED
|
@@ -12,6 +12,7 @@ export function SettingsPage() {
|
|
| 12 |
const [testError, setTestError] = useState<string | null>(null);
|
| 13 |
const [testing, setTesting] = useState(false);
|
| 14 |
const [message, setMessage] = useState<string | null>(null);
|
|
|
|
| 15 |
|
| 16 |
const loadStatus = () => {
|
| 17 |
getModelsStatus()
|
|
@@ -38,6 +39,7 @@ export function SettingsPage() {
|
|
| 38 |
return;
|
| 39 |
}
|
| 40 |
setMessage(null);
|
|
|
|
| 41 |
try {
|
| 42 |
await setModel(name);
|
| 43 |
setActive(name);
|
|
@@ -46,6 +48,8 @@ export function SettingsPage() {
|
|
| 46 |
} catch (e) {
|
| 47 |
setMessage(e instanceof Error ? e.message : "Failed to switch model");
|
| 48 |
loadStatus();
|
|
|
|
|
|
|
| 49 |
}
|
| 50 |
};
|
| 51 |
|
|
@@ -72,6 +76,9 @@ export function SettingsPage() {
|
|
| 72 |
HF models need <code>uv sync --extra hf</code> locally, or{" "}
|
| 73 |
<code>INSTALL_HF=1 docker compose build</code> in Docker.
|
| 74 |
</p>
|
|
|
|
|
|
|
|
|
|
| 75 |
<div className="model-list">
|
| 76 |
{modelStatus.map((m) => (
|
| 77 |
<label
|
|
@@ -82,7 +89,7 @@ export function SettingsPage() {
|
|
| 82 |
type="radio"
|
| 83 |
name="model"
|
| 84 |
checked={active === m.name}
|
| 85 |
-
disabled={!m.available}
|
| 86 |
onChange={() => void switchModel(m.name)}
|
| 87 |
/>
|
| 88 |
<span>
|
|
|
|
| 12 |
const [testError, setTestError] = useState<string | null>(null);
|
| 13 |
const [testing, setTesting] = useState(false);
|
| 14 |
const [message, setMessage] = useState<string | null>(null);
|
| 15 |
+
const [switching, setSwitching] = useState(false);
|
| 16 |
|
| 17 |
const loadStatus = () => {
|
| 18 |
getModelsStatus()
|
|
|
|
| 39 |
return;
|
| 40 |
}
|
| 41 |
setMessage(null);
|
| 42 |
+
setSwitching(true);
|
| 43 |
try {
|
| 44 |
await setModel(name);
|
| 45 |
setActive(name);
|
|
|
|
| 48 |
} catch (e) {
|
| 49 |
setMessage(e instanceof Error ? e.message : "Failed to switch model");
|
| 50 |
loadStatus();
|
| 51 |
+
} finally {
|
| 52 |
+
setSwitching(false);
|
| 53 |
}
|
| 54 |
};
|
| 55 |
|
|
|
|
| 76 |
HF models need <code>uv sync --extra hf</code> locally, or{" "}
|
| 77 |
<code>INSTALL_HF=1 docker compose build</code> in Docker.
|
| 78 |
</p>
|
| 79 |
+
{switching && (
|
| 80 |
+
<p className="hint">Switching model… HF models may take up to a minute on first load.</p>
|
| 81 |
+
)}
|
| 82 |
<div className="model-list">
|
| 83 |
{modelStatus.map((m) => (
|
| 84 |
<label
|
|
|
|
| 89 |
type="radio"
|
| 90 |
name="model"
|
| 91 |
checked={active === m.name}
|
| 92 |
+
disabled={!m.available || switching}
|
| 93 |
onChange={() => void switchModel(m.name)}
|
| 94 |
/>
|
| 95 |
<span>
|
scripts/materialize_finetuned_weights.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Download real HF weights into models/finetuned_hf (no Git LFS required).
|
| 3 |
+
|
| 4 |
+
The repo may only contain a Git LFS pointer for model.safetensors (~134 bytes).
|
| 5 |
+
This script saves a compatible DistilBERT toxic classifier from Hugging Face Hub
|
| 6 |
+
so "Fine-tuned (local HF)" can load offline after one download.
|
| 7 |
+
|
| 8 |
+
Run from repo root:
|
| 9 |
+
uv sync --extra hf
|
| 10 |
+
uv run python scripts/materialize_finetuned_weights.py
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 19 |
+
OUT_DIR = PROJECT_ROOT / "models" / "finetuned_hf"
|
| 20 |
+
# Same architecture family as notebook 08 (DistilBERT sequence classification)
|
| 21 |
+
HUB_ID = "martin-ha/toxic-comment-model"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main() -> int:
    """Make sure ``OUT_DIR`` contains real model weights, downloading if needed.

    Returns:
        0 when ``model.safetensors`` already exists with a plausible size or
        was written successfully; 1 when the Hugging Face dependencies are
        missing or no usable weights file exists after saving.
    """
    weights = OUT_DIR / "model.safetensors"
    # Anything at or below ~1 MB is treated as a placeholder (e.g. a Git LFS pointer).
    if weights.is_file() and weights.stat().st_size > 1_000_000:
        print(f"OK: {weights} already exists ({weights.stat().st_size // 1_000_000} MB)")
        return 0

    try:
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
    except ImportError:
        print("Install HF deps first: uv sync --extra hf", file=sys.stderr)
        return 1

    print(f"Downloading {HUB_ID} into {OUT_DIR} …")
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    model = AutoModelForSequenceClassification.from_pretrained(HUB_ID)
    tokenizer = AutoTokenizer.from_pretrained(HUB_ID)
    # Ask for safetensors explicitly: ``model.safetensors`` is the file this
    # script checks, so the output format must not depend on the installed
    # library's default. This also overwrites any placeholder at that path.
    model.save_pretrained(OUT_DIR, safe_serialization=True)
    tokenizer.save_pretrained(OUT_DIR)

    meta = OUT_DIR / "model_metadata.json"
    if not meta.exists():
        meta.write_text(
            '{"model_name":"DistilBERT (materialized from Hub)","note":"Run notebook 08 to replace with team weights"}\n',
            encoding="utf-8",
        )

    # Report failure instead of "Done ... (0 MB)" when no real weights landed on disk.
    if not weights.is_file() or weights.stat().st_size <= 1_000_000:
        print(f"ERROR: {weights} is missing or too small after download", file=sys.stderr)
        return 1

    size_mb = weights.stat().st_size // 1_000_000
    print(f"Done. {weights} ({size_mb} MB)")
    return 0
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
raise SystemExit(main())
|
src/api/main.py
CHANGED
|
@@ -87,7 +87,14 @@ app.include_router(predict.router)
|
|
| 87 |
app.include_router(videos.router)
|
| 88 |
|
| 89 |
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
def _mount_frontend() -> None:
|
|
@@ -99,7 +106,7 @@ def _mount_frontend() -> None:
|
|
| 99 |
|
| 100 |
@app.get("/{full_path:path}", include_in_schema=False)
|
| 101 |
async def spa_fallback(full_path: str):
|
| 102 |
-
if
|
| 103 |
from fastapi import HTTPException
|
| 104 |
|
| 105 |
raise HTTPException(status_code=404, detail="Not found")
|
|
|
|
| 87 |
app.include_router(videos.router)
|
| 88 |
|
| 89 |
|
| 90 |
+
# First path segments that belong to the API or its generated docs. Requests
# under these roots must not be answered by the SPA fallback.
_API_PATH_ROOTS = frozenset(
    {
        "models",
        "model",
        "videos",
        "predict",
        "health",
        "docs",
        "redoc",
        "openapi",
        # The whole first segment is compared, so the default schema file name
        # needs its own entry; "openapi" alone never matches "openapi.json".
        "openapi.json",
    }
)


def _is_api_spa_path(full_path: str) -> bool:
    """Return True when the first segment of ``full_path`` is an API root.

    Args:
        full_path: Request path without a leading slash, as captured by the
            catch-all route. An empty string yields False.
    """
    root, _, _ = full_path.partition("/")
    return root in _API_PATH_ROOTS
|
| 98 |
|
| 99 |
|
| 100 |
def _mount_frontend() -> None:
|
|
|
|
| 106 |
|
| 107 |
@app.get("/{full_path:path}", include_in_schema=False)
|
| 108 |
async def spa_fallback(full_path: str):
|
| 109 |
+
if _is_api_spa_path(full_path):
|
| 110 |
from fastapi import HTTPException
|
| 111 |
|
| 112 |
raise HTTPException(status_code=404, detail="Not found")
|
src/api/routes/models.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import time
|
|
|
|
| 2 |
|
| 3 |
from fastapi import APIRouter, HTTPException
|
| 4 |
|
| 5 |
-
from src.api.schemas import ModelInfo, ModelStatusEntry, ModelsStatusResponse
|
| 6 |
from src.api.services import get_service
|
| 7 |
from src.api.state import PROJECT_ROOT, get_state
|
| 8 |
from src.service.model_service import AVAILABLE_MODELS, ModelService, check_model_availability
|
|
@@ -10,6 +11,33 @@ from src.service.model_service import AVAILABLE_MODELS, ModelService, check_mode
|
|
| 10 |
router = APIRouter(tags=["Model"])
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
@router.get("/model-info", response_model=ModelInfo)
|
| 14 |
async def get_model_info():
|
| 15 |
service = get_service()
|
|
@@ -50,29 +78,13 @@ async def list_models():
|
|
| 50 |
return {"available": list(AVAILABLE_MODELS.keys()), "active": state["model_name"]}
|
| 51 |
|
| 52 |
|
| 53 |
-
@router.
|
| 54 |
-
async def
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
status_code=400,
|
| 58 |
-
detail=f"Model '{model_name}' not available. Options: {list(AVAILABLE_MODELS.keys())}",
|
| 59 |
-
)
|
| 60 |
|
| 61 |
-
available, reason = check_model_availability(model_name, PROJECT_ROOT)
|
| 62 |
-
if not available:
|
| 63 |
-
raise HTTPException(status_code=400, detail=reason or "Model unavailable")
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
new_service = ModelService(model_name, PROJECT_ROOT)
|
| 70 |
-
warmup = new_service.predict("warmup")
|
| 71 |
-
if warmup.get("error"):
|
| 72 |
-
state["service"] = prev_service
|
| 73 |
-
state["model_name"] = prev_name
|
| 74 |
-
raise HTTPException(status_code=400, detail=str(warmup["error"]))
|
| 75 |
-
|
| 76 |
-
state["service"] = new_service
|
| 77 |
-
state["model_name"] = model_name
|
| 78 |
-
return {"message": f"Active model set to '{model_name}'", "model": model_name}
|
|
|
|
| 1 |
import time
|
| 2 |
+
from urllib.parse import unquote
|
| 3 |
|
| 4 |
from fastapi import APIRouter, HTTPException
|
| 5 |
|
| 6 |
+
from src.api.schemas import ModelInfo, ModelStatusEntry, ModelsStatusResponse, SelectModelRequest
|
| 7 |
from src.api.services import get_service
|
| 8 |
from src.api.state import PROJECT_ROOT, get_state
|
| 9 |
from src.service.model_service import AVAILABLE_MODELS, ModelService, check_model_availability
|
|
|
|
| 11 |
router = APIRouter(tags=["Model"])
|
| 12 |
|
| 13 |
|
| 14 |
+
def _switch_model_impl(model_name: str) -> dict[str, str]:
    """Validate, load and activate ``model_name``.

    The shared application state is only written after the new service has
    been constructed and has answered a warm-up prediction without error, so
    a failed switch leaves the previously active model in place.

    Args:
        model_name: Catalog key of the model to activate.

    Returns:
        A dict with a human-readable ``message`` and the activated ``model``.

    Raises:
        HTTPException: 400 when the name is not in the catalog, the model is
            reported unavailable, or loading / warm-up fails.
    """
    if model_name not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_name}' not available. Options: {list(AVAILABLE_MODELS.keys())}",
        )

    available, reason = check_model_availability(model_name, PROJECT_ROOT)
    if not available:
        raise HTTPException(status_code=400, detail=reason or "Model unavailable")

    state = get_state()

    try:
        new_service = ModelService(model_name, PROJECT_ROOT)
        warmup = new_service.predict("warmup")
    except Exception as exc:
        # Request boundary: report load failures to the client with the cause
        # instead of letting them escape as an unhandled server error.
        raise HTTPException(
            status_code=400,
            detail=f"Failed to load model '{model_name}': {exc}",
        ) from exc

    if warmup.get("error"):
        # Nothing to roll back: state has not been modified yet.
        raise HTTPException(status_code=400, detail=str(warmup["error"]))

    state["service"] = new_service
    state["model_name"] = model_name
    return {"message": f"Active model set to '{model_name}'", "model": model_name}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
@router.get("/model-info", response_model=ModelInfo)
|
| 42 |
async def get_model_info():
|
| 43 |
service = get_service()
|
|
|
|
| 78 |
return {"available": list(AVAILABLE_MODELS.keys()), "active": state["model_name"]}
|
| 79 |
|
| 80 |
|
| 81 |
+
@router.post("/models/select")
async def select_model(body: SelectModelRequest):
    """Switch the active model named in the JSON body (preferred endpoint).

    Carrying the name in the body avoids URL-encoding issues in model names.
    """
    requested = body.model_name.strip()
    return _switch_model_impl(requested)
|
|
|
|
|
|
|
|
|
|
| 85 |
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
@router.put("/model/{model_name:path}")
async def switch_model(model_name: str):
    """Legacy path-based model switch.

    The path parameter is tried exactly as received first. Per the ASGI
    specification it already arrives percent-decoded, so decoding it again
    unconditionally would corrupt a catalog name containing a literal ``%XX``
    sequence. A second decode is applied only when the received value is not a
    known model, which keeps double-encoding clients working.
    """
    candidate = model_name.strip()
    if candidate not in AVAILABLE_MODELS:
        candidate = unquote(candidate).strip()
    return _switch_model_impl(candidate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/api/schemas.py
CHANGED
|
@@ -69,6 +69,10 @@ class ModelsStatusResponse(BaseModel):
|
|
| 69 |
active: str
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
class ModelInfo(BaseModel):
|
| 73 |
name: str
|
| 74 |
type: str
|
|
|
|
| 69 |
active: str
|
| 70 |
|
| 71 |
|
| 72 |
+
class SelectModelRequest(BaseModel):
    """Request body for ``POST /models/select``."""

    # Required, non-empty name of the model to activate.
    model_name: str = Field(default=..., min_length=1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
class ModelInfo(BaseModel):
|
| 77 |
name: str
|
| 78 |
type: str
|
src/service/model_service.py
CHANGED
|
@@ -14,6 +14,44 @@ from src.service.model_catalog import load_model_catalog
|
|
| 14 |
AVAILABLE_MODELS: dict[str, dict[str, Any]] = load_model_catalog()
|
| 15 |
|
| 16 |
_HF_DEPS_MSG = "Install HF deps: uv sync --extra hf"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def hf_deps_available() -> bool:
|
|
@@ -49,9 +87,12 @@ def check_model_availability(name: str, project_root: Path | None = None) -> tup
|
|
| 49 |
if not hf_deps_available():
|
| 50 |
return False, _HF_DEPS_MSG
|
| 51 |
path = root / cfg["model_path"]
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
if model_type == "hf_remote":
|
| 57 |
if not hf_deps_available():
|
|
@@ -118,9 +159,13 @@ class ModelService:
|
|
| 118 |
self._load_hf(self.cfg["model_id"])
|
| 119 |
elif t == "hf_local":
|
| 120 |
path = self.project_root / self.cfg["model_path"]
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
return self._model
|
| 125 |
|
| 126 |
def _load_local(self) -> None:
|
|
|
|
| 14 |
AVAILABLE_MODELS: dict[str, dict[str, Any]] = load_model_catalog()
|
| 15 |
|
| 16 |
_HF_DEPS_MSG = "Install HF deps: uv sync --extra hf"
|
| 17 |
+
_LFS_POINTER_PREFIX = "version https://git-lfs"
_MIN_LOCAL_HF_WEIGHTS_BYTES = 1_000_000


def _is_lfs_pointer_file(path: Path) -> bool:
    """Return True when ``path`` is a small file that starts like a Git LFS pointer.

    Files larger than 4096 bytes are never treated as pointers, and any
    ``OSError`` while inspecting the file yields False.
    """
    try:
        small_enough = path.stat().st_size <= 4096
        # Only small files are read; larger ones are rejected on size alone.
        leading_text = (
            path.read_text(encoding="utf-8", errors="ignore")[:80] if small_enough else ""
        )
    except OSError:
        return False
    return leading_text.startswith(_LFS_POINTER_PREFIX)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def local_hf_weights_ok(model_dir: Path) -> tuple[bool, str | None]:
    """Check that a local HF model folder holds real weights, not placeholders.

    Only the first weights file that exists is inspected, looking for
    ``model.safetensors`` and then ``pytorch_model.bin``.

    Returns:
        ``(True, None)`` when that file is neither a Git LFS pointer nor
        smaller than the minimum size; otherwise ``(False, reason)``.
    """
    if not model_dir.is_dir():
        return False, f"Model not found at {model_dir}."

    candidates = (model_dir / fname for fname in ("model.safetensors", "pytorch_model.bin"))
    found = next((candidate for candidate in candidates if candidate.is_file()), None)
    if found is None:
        return False, "No model.safetensors or pytorch_model.bin in model directory."

    if _is_lfs_pointer_file(found):
        pointer_msg = (
            "Weights missing (Git LFS pointer only). "
            "Run: uv run python scripts/materialize_finetuned_weights.py "
            "(or: brew install git-lfs && git lfs pull)"
        )
        return False, pointer_msg

    n_bytes = found.stat().st_size
    if n_bytes < _MIN_LOCAL_HF_WEIGHTS_BYTES:
        too_small_msg = (
            f"{found.name} is too small ({n_bytes} bytes). "
            "Run: uv run python scripts/materialize_finetuned_weights.py"
        )
        return False, too_small_msg

    return True, None
|
| 55 |
|
| 56 |
|
| 57 |
def hf_deps_available() -> bool:
|
|
|
|
| 87 |
if not hf_deps_available():
|
| 88 |
return False, _HF_DEPS_MSG
|
| 89 |
path = root / cfg["model_path"]
|
| 90 |
+
ok, reason = local_hf_weights_ok(path)
|
| 91 |
+
if ok:
|
| 92 |
+
return True, None
|
| 93 |
+
if cfg.get("hub_fallback"):
|
| 94 |
+
return True, reason
|
| 95 |
+
return False, reason
|
| 96 |
|
| 97 |
if model_type == "hf_remote":
|
| 98 |
if not hf_deps_available():
|
|
|
|
| 159 |
self._load_hf(self.cfg["model_id"])
|
| 160 |
elif t == "hf_local":
|
| 161 |
path = self.project_root / self.cfg["model_path"]
|
| 162 |
+
ok, _reason = local_hf_weights_ok(path)
|
| 163 |
+
if ok:
|
| 164 |
+
self._load_hf(str(path))
|
| 165 |
+
elif self.cfg.get("hub_fallback"):
|
| 166 |
+
self._load_hf(self.cfg["hub_fallback"])
|
| 167 |
+
else:
|
| 168 |
+
raise FileNotFoundError(_reason or f"Model not found at {path}.")
|
| 169 |
return self._model
|
| 170 |
|
| 171 |
def _load_local(self) -> None:
|
tests/test_api.py
CHANGED
|
@@ -99,6 +99,29 @@ def test_predict_video_demo_comments_differ_by_url(client: TestClient, monkeypat
|
|
| 99 |
assert data1["results"][0]["text"] != data2["results"][0]["text"]
|
| 100 |
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
def test_models_status_lists_catalog(client: TestClient):
|
| 103 |
response = client.get("/models/status")
|
| 104 |
assert response.status_code == 200
|
|
|
|
| 99 |
assert data1["results"][0]["text"] != data2["results"][0]["text"]
|
| 100 |
|
| 101 |
|
| 102 |
+
def test_finetuned_local_reports_lfs_when_pointer_only():
    """A pointer-only weights file must be reported as unusable local weights.

    The local-folder check is asserted directly. The catalog entry configures
    a Hub fallback, so ``check_model_availability`` can still report the model
    as selectable; for that call only the presence of a reason is asserted.
    """
    import pytest

    from src.api.state import PROJECT_ROOT
    from src.service.model_service import check_model_availability, local_hf_weights_ok

    model_dir = PROJECT_ROOT / "models" / "finetuned_hf"
    weights = model_dir / "model.safetensors"
    if not weights.is_file() or weights.stat().st_size >= 4096:
        pytest.skip("finetuned_hf weights present or missing — LFS pointer test N/A")

    ok, reason = local_hf_weights_ok(model_dir)
    assert ok is False
    assert reason is not None
    assert "materialize" in reason.lower() or "lfs" in reason.lower()

    # Whether or not the model is selectable (Hub fallback, missing HF deps),
    # the caller must be told why the local weights are not being used.
    _selectable, availability_reason = check_model_availability(
        "Fine-tuned (local HF)", PROJECT_ROOT
    )
    assert availability_reason is not None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_select_model_via_post(client: TestClient):
    """POST /models/select activates the model named in the JSON body."""
    wanted = "LR + TF-IDF (local)"
    response = client.post("/models/select", json={"model_name": wanted})
    assert response.status_code == 200
    payload = response.json()
    assert payload["model"] == wanted
|
| 123 |
+
|
| 124 |
+
|
| 125 |
def test_models_status_lists_catalog(client: TestClient):
|
| 126 |
response = client.get("/models/status")
|
| 127 |
assert response.status_code == 200
|