# polyscriptor-htr-demo / web / polyscriptor_server.py
# Achim Rabus
# Deploy Polyscriptor HTR Space demo
# 78431ff
"""
Polyscriptor Web UI — FastAPI Backend

Thin wrapper around existing HTR engine code. Provides REST API + SSE
for browser-based transcription. All heavy lifting is done by the same
modules the PyQt6 GUI uses.

Usage:
    source htr_gui/bin/activate
    python -m uvicorn web.polyscriptor_server:app --host 0.0.0.0 --port 8765

Author: Claude Code
Date: 2026-02-26
"""
import asyncio
import hashlib
import importlib
import json
import logging
import os
import sys
import time
import uuid
from dataclasses import dataclass, field
from types import SimpleNamespace
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
from PIL import Image, ImageOps
from fastapi import Cookie, FastAPI, File, HTTPException, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# Module-wide logger used by every helper and route in this file.
log = logging.getLogger("polyscriptor")
# Deployment flavour, normalised (trimmed, lower-cased); compared against "hf_space"
# when the session cookie attributes are chosen.
DEMO_MODE = os.environ.get("POLYSCRIPTOR_DEMO_MODE", "").strip().lower()
# Add project root to path so we can import existing modules
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# Load .env from project root (same as the Qt GUI does via CommercialAPIEngine)
try:
    from dotenv import load_dotenv
    _env_path = PROJECT_ROOT / ".env"
    if _env_path.exists():
        load_dotenv(_env_path)
        log.info(f"Loaded environment variables from {_env_path}")
except ImportError:
    pass  # python-dotenv not installed — env vars must be set externally
# Must come after the sys.path tweak above: htr_engine_base is resolved via PROJECT_ROOT.
from htr_engine_base import get_global_registry, HTREngine, TranscriptionResult
# PDF support via PyMuPDF
try:
    import fitz as _fitz  # PyMuPDF
    PDF_AVAILABLE = True
except ImportError:
    # PDF_AVAILABLE records whether the optional PyMuPDF import succeeded.
    PDF_AVAILABLE = False
    log.warning("PyMuPDF not installed — PDF upload disabled. Install with: pip install pymupdf")
# Lazy imports for segmentation (avoid slow startup)
# Flipped to True once the segmentation modules have been imported.
_segmenters_imported = False


def _import_segmenters():
    """Import the segmentation modules on first call and publish them as globals.

    Binds ``KrakenLineSegmenter``, ``LineSegmenter`` and ``PYLAIA_MODELS`` at
    module level. ``PYLAIA_MODELS`` falls back to an empty dict when
    ``inference_pylaia_native`` cannot be imported. Subsequent calls are no-ops.
    """
    global _segmenters_imported, KrakenLineSegmenter, LineSegmenter, PYLAIA_MODELS
    if not _segmenters_imported:
        from kraken_segmenter import KrakenLineSegmenter
        from inference_page import LineSegmenter
        try:
            from inference_pylaia_native import PYLAIA_MODELS
        except ImportError:
            PYLAIA_MODELS = {}
        _segmenters_imported = True
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
# The ASGI application object; routes, middleware and startup hooks below attach to it.
app = FastAPI(title="Polyscriptor HTR", version="0.1.0")
# Serve static frontend files
STATIC_DIR = Path(__file__).parent / "static"
# Everything under STATIC_DIR is exposed below the /static URL prefix.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
# ---------------------------------------------------------------------------
# Engine pool — Phase 2: shared pool of loaded engine instances
# ---------------------------------------------------------------------------
@dataclass
class EngineSlot:
    """One loaded engine instance in the pool."""
    engine: Any  # HTREngine instance (not the registry singleton)
    engine_name: str  # engine display name, e.g. a key of _ENGINE_FACTORY
    config: dict  # configuration the engine was loaded with
    pool_key: str  # key under which this slot is stored in engine_pool
    ref_count: int = 0  # decremented when a referencing session expires; 0 makes the slot evictable
    last_used: float = field(default_factory=time.time)  # epoch seconds, drives LRU / idle eviction
    # Per-slot asyncio lock. NOTE(review): presumably serialises use of the engine — callers not visible here.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
# All currently loaded engines, keyed by the string produced by _make_pool_key().
engine_pool: Dict[str, EngineSlot] = {}
# Guards structural changes to engine_pool (see _maybe_evict: "Called UNDER pool_lock").
pool_lock = asyncio.Lock()
# VRAM budget estimates (GB) for eviction decisions
# Engines missing from this table are costed at 4 GB by _maybe_evict().
_ENGINE_VRAM_GB = {
    "CRNN-CTC (PyLaia-inspired)": 2,
    "TrOCR": 3,
    "Qwen3-VL": 18,
    "Churro VLM": 10,
    "Kraken": 2,
    "Party": 4,
    "PaddleOCR": 2,
}
# Engines excluded from VRAM accounting and from idle/LRU eviction.
_NO_GPU_ENGINES = {"Commercial APIs", "OpenWebUI", "LightOnOCR", "DeepSeek-OCR"}
_TOTAL_VRAM_GB = 92  # 2x L40S @ 46GB each
# Factory: engine name -> (module, class) for creating fresh instances
_ENGINE_FACTORY = {
    "TrOCR": ("engines.trocr_engine", "TrOCREngine"),
    "CRNN-CTC (PyLaia-inspired)": ("engines.pylaia_engine", "PyLaiaEngine"),
    "Qwen3-VL": ("engines.qwen3_engine", "Qwen3Engine"),
    "Churro VLM": ("engines.churro_engine", "ChurroEngine"),
    "Kraken": ("engines.kraken_engine", "KrakenEngine"),
    "Commercial APIs": ("engines.commercial_api_engine", "CommercialAPIEngine"),
    "Party": ("engines.party_engine", "PartyEngine"),
    "OpenWebUI": ("engines.openwebui_engine", "OpenWebUIEngine"),
    "DeepSeek-OCR": ("engines.deepseek_ocr_engine", "DeepSeekOCREngine"),
    "LightOnOCR": ("engines.lighton_ocr_engine", "LightOnOCREngine"),
    "PaddleOCR": ("engines.paddle_engine", "PaddleOCREngine"),
}
def _create_engine_instance(engine_name: str):
    """Instantiate a brand-new engine object for ``engine_name``.

    The registry singleton is only used for discovery/availability; every pool
    slot gets its own instance so that several models can be resident at once.

    Returns ``None`` when ``engine_name`` has no entry in ``_ENGINE_FACTORY``.
    """
    try:
        module_name, class_name = _ENGINE_FACTORY[engine_name]
    except KeyError:
        return None
    engine_cls = getattr(importlib.import_module(module_name), class_name)
    return engine_cls()
def _make_pool_key(engine_name: str, config: dict) -> str:
    """Return the pool key identifying one engine + model combination.

    The key always starts with ``engine_name`` followed by ``::``-separated
    parts that depend on the engine:

    * API-backed engines: provider or base URL, model, and an 8-hex-digit
      SHA-256 prefix of the API key (``nokey`` when no key is given), so the
      raw key never appears in the pool key.
    * Path-based engines: the ``model_path`` (``default`` when absent).
    * Qwen3-VL: base model plus adapter (``none`` when empty).
    * Churro VLM: the ``model_name``.
    * Anything else: a 12-hex-digit SHA-256 prefix of the sorted config items.
    """
    def _digest(text: str, length: int) -> str:
        return hashlib.sha256(text.encode()).hexdigest()[:length]

    if engine_name in ("Commercial APIs", "OpenWebUI"):
        if engine_name == "Commercial APIs":
            origin = config.get("provider", "unknown")
        else:
            origin = config.get("base_url", "unknown")
        model_id = config.get("model", "unknown")
        secret = config.get("api_key", "")
        key_part = _digest(secret, 8) if secret else "nokey"
        return f"{engine_name}::{origin}::{model_id}::{key_part}"

    if engine_name in ("TrOCR", "CRNN-CTC (PyLaia-inspired)", "Kraken", "LightOnOCR"):
        return f"{engine_name}::{config.get('model_path', 'default')}"

    if engine_name == "Qwen3-VL":
        base_model = config.get("base_model", "default")
        adapter_part = config.get("adapter", "") or "none"
        return f"{engine_name}::{base_model}::{adapter_part}"

    if engine_name == "Churro VLM":
        return f"{engine_name}::{config.get('model_name', 'default')}"

    # Fallback: hash the config
    return f"{engine_name}::{_digest(str(sorted(config.items())), 12)}"
async def _maybe_evict(new_engine_name: str):
    """Free estimated VRAM for ``new_engine_name`` by unloading idle engines.

    Must be called while holding ``pool_lock``. Engines listed in
    ``_NO_GPU_ENGINES`` neither trigger nor suffer eviction. Only slots with
    ``ref_count == 0`` are candidates; they are unloaded least-recently-used
    first until the estimate fits into ``_TOTAL_VRAM_GB``. If it still does
    not fit afterwards, a warning is logged and the caller proceeds anyway.
    """
    if new_engine_name in _NO_GPU_ENGINES:
        return

    def _estimate(name: str):
        # Engines without an entry in the table are costed at 4 GB.
        return _ENGINE_VRAM_GB.get(name, 4)

    needed = _estimate(new_engine_name)
    gpu_slots = {
        key: slot
        for key, slot in engine_pool.items()
        if slot.engine_name not in _NO_GPU_ENGINES
    }
    used = sum(_estimate(slot.engine_name) for slot in gpu_slots.values())
    if used + needed <= _TOTAL_VRAM_GB:
        return

    # Evict: ref_count==0, oldest first
    idle = [(key, slot) for key, slot in gpu_slots.items() if slot.ref_count == 0]
    idle.sort(key=lambda pair: pair[1].last_used)
    for key, slot in idle:
        if used + needed <= _TOTAL_VRAM_GB:
            break
        log.info(f"Evicting engine slot '{key}' (last used {time.time() - slot.last_used:.0f}s ago)")
        try:
            slot.engine.unload_model()
        except Exception as e:
            log.warning(f"Error unloading evicted engine: {e}")
        del engine_pool[key]
        used -= _estimate(slot.engine_name)

    if used + needed > _TOTAL_VRAM_GB:
        log.warning(f"VRAM tight: ~{used}GB used + ~{needed}GB needed > {_TOTAL_VRAM_GB}GB total")
# Compatibility shims — will be removed after full migration
# They mirror the engine auto-loaded at startup (see startup_event).
loaded_engine: Optional[HTREngine] = None
loaded_engine_name: str = ""
loaded_config: dict = {}
# Persistent upload storage (survives server restarts)
UPLOAD_DIR = Path(__file__).parent / "uploads"
# Created at import time; exist_ok=True tolerates an already existing directory.
UPLOAD_DIR.mkdir(exist_ok=True)
# Upload TTL: 24 hours
_UPLOAD_TTL_SECONDS = 86400
# Session TTL: 2 hours of inactivity
# Also used as the max_age of the session cookie.
_SESSION_TTL_SECONDS = 7200
# Cookie name for session tracking
_SESSION_COOKIE = "polyscriptor_session"
# ---------------------------------------------------------------------------
# Per-user sessions — Phase 1 of multi-user refactoring
# ---------------------------------------------------------------------------
@dataclass
class UserSession:
    # Per-browser state, looked up through the session cookie.
    session_id: str  # UUID4 string; also the cookie value
    # image_id -> metadata dict; the cleanup code reads its "path" and "xml_path" entries
    image_cache: Dict[str, dict] = field(default_factory=dict)
    # NOTE(review): presumably keyed by job/image id to cancel running transcriptions — usage not visible here
    cancel_events: Dict[str, asyncio.Event] = field(default_factory=dict)
    pool_key: Optional[str] = None  # Reference into engine_pool
    created_at: float = field(default_factory=time.time)  # epoch seconds
    last_active: float = field(default_factory=time.time)  # epoch seconds; compared against _SESSION_TTL_SECONDS
# All live sessions, keyed by session_id.
sessions: Dict[str, UserSession] = {}
# NOTE(review): image cache shared across sessions — its readers/writers are not visible in this part of the file.
global_image_cache: Dict[str, dict] = {}
def _get_or_create_session(session_id: Optional[str]) -> tuple[UserSession, bool]:
    """Look up the session for ``session_id`` or register a new one.

    Returns ``(session, created)``. A known session gets its ``last_active``
    refreshed and ``created`` is ``False``. A missing or unknown id yields a
    fresh session under a new UUID4 and ``created`` is ``True``.
    """
    existing = sessions.get(session_id) if session_id else None
    if existing is not None:
        existing.last_active = time.time()
        return existing, False

    fresh = UserSession(session_id=str(uuid.uuid4()))
    sessions[fresh.session_id] = fresh
    return fresh, True
def _cleanup_expired_sessions() -> int:
    """Drop every session idle for longer than ``_SESSION_TTL_SECONDS``.

    For each expired session the reference on its engine slot is released
    (the engine is unloaded and removed from the pool once nobody references
    it any more) and the uploaded image / XML files recorded in its image
    cache are deleted from disk.

    Returns the number of sessions removed.
    """
    cutoff = time.time() - _SESSION_TTL_SECONDS
    expired = [sid for sid, candidate in sessions.items() if candidate.last_active < cutoff]

    for sid in expired:
        session = sessions.pop(sid)
        pool_key = session.pool_key

        # Release pool reference
        if pool_key and pool_key in engine_pool:
            slot = engine_pool[pool_key]
            slot.ref_count = max(0, slot.ref_count - 1)
            if slot.ref_count == 0:
                log.info(f"Immediate eviction (session expiry): '{slot.engine_name}'")
                try:
                    slot.engine.unload_model()
                except Exception as e:
                    log.warning(f"unload_model() failed for '{slot.engine_name}': {e}")
                engine_pool.pop(pool_key, None)

        # Clean up upload files belonging to this session
        for img_data in session.image_cache.values():
            for path_key in ("path", "xml_path"):
                stored = img_data.get(path_key)
                if stored:
                    Path(stored).unlink(missing_ok=True)

        log.info(f"Expired session {sid[:8]}... ({len(session.image_cache)} images)")

    return len(expired)
# Status/discovery routes that must not count as user activity.
_SESSION_PASSTHROUGH_PATHS = {"/api/gpu", "/api/engines", "/api/kraken/presets"}


@app.middleware("http")
async def session_middleware(request: Request, call_next):
    """Inject session into request.state; set session cookie on new sessions.

    Pure status/discovery routes (GPU poll, engine list) are excluded from
    last_active updates so that background browser polling cannot keep a session
    alive indefinitely and prevent engine-slot eviction.

    Args:
        request: incoming HTTP request; the session id is read from its cookies.
        call_next: ASGI continuation that produces the downstream response.

    Returns:
        The downstream response, with the session cookie attached when a new
        session was created for this request.
    """
    session_id = request.cookies.get(_SESSION_COOKIE)

    # _get_or_create_session() refreshes last_active on every successful lookup.
    # Remember the previous value so the refresh can be undone for polling routes;
    # otherwise the exclusion promised above would never take effect.
    known = sessions.get(session_id)
    previous_active = known.last_active if known is not None else None

    session, created = _get_or_create_session(session_id)
    request.state.session = session

    if request.url.path in _SESSION_PASSTHROUGH_PATHS:
        # Polling-only route: keep the timestamp of the last real activity.
        # A session created by this very request keeps its creation time.
        if previous_active is not None:
            session.last_active = previous_active
    else:
        session.last_active = time.time()

    response = await call_next(request)

    # NOTE(review): the cookie is only (re)issued when a session is created, so its
    # max_age counts from creation rather than from the last activity.
    if created or session_id != session.session_id:
        cookie_kwargs = {
            "key": _SESSION_COOKIE,
            "value": session.session_id,
            "httponly": True,
            "max_age": _SESSION_TTL_SECONDS,
        }
        if DEMO_MODE == "hf_space":
            # SameSite=None requires Secure; used when DEMO_MODE is "hf_space".
            cookie_kwargs.update({"samesite": "none", "secure": True})
        else:
            cookie_kwargs.update({"samesite": "lax"})
        response.set_cookie(**cookie_kwargs)
    return response
def _get_session(request: Request) -> UserSession:
    """FastAPI dependency returning the session that the middleware stored on the request."""
    state = request.state
    return state.session
def _cleanup_old_uploads() -> int:
    """Remove expired upload files and forget cache entries whose file is gone.

    Files directly inside ``UPLOAD_DIR`` whose modification time is older than
    ``_UPLOAD_TTL_SECONDS`` are deleted; OS errors on individual files are
    ignored. Afterwards every session's ``image_cache`` is pruned of entries
    whose ``path`` no longer exists on disk.

    Returns the number of files deleted.
    """
    cutoff = time.time() - _UPLOAD_TTL_SECONDS
    deleted = 0

    # Snapshot the listing first: entries are unlinked while we walk it.
    for entry in list(UPLOAD_DIR.iterdir()):
        if not entry.is_file():
            continue
        try:
            if entry.stat().st_mtime < cutoff:
                entry.unlink(missing_ok=True)
                deleted += 1
        except OSError:
            pass

    # Evict stale image_cache entries whose file no longer exists (all sessions)
    for session in sessions.values():
        cache = session.image_cache
        for image_id in list(cache.keys()):
            stored = cache[image_id].get("path")
            if stored and not Path(stored).exists():
                del cache[image_id]

    return deleted
_SLOT_IDLE_TTL_SECONDS = 6 * 3600  # evict loaded engines idle for 6h, regardless of ref_count


def _evict_idle_slots() -> int:
    """Unload GPU engines whose slot was last used more than ``_SLOT_IDLE_TTL_SECONDS`` ago.

    Takes no lock, so it must only be called from ``_periodic_cleanup``.
    ``ref_count`` is deliberately ignored: this caps how long an engine may stay
    resident independently of session expiry. Sessions that pointed at an
    evicted slot get their ``pool_key`` reset to ``None``. Engines listed in
    ``_NO_GPU_ENGINES`` are never touched.

    Returns the number of slots evicted.
    """
    cutoff = time.time() - _SLOT_IDLE_TTL_SECONDS

    stale_keys = []
    for key, slot in engine_pool.items():
        if slot.engine_name in _NO_GPU_ENGINES:
            continue
        if slot.last_used < cutoff:
            stale_keys.append(key)

    for key in stale_keys:
        slot = engine_pool.pop(key)
        log.info(f"Idle eviction: '{slot.engine_name}' (idle {(time.time() - slot.last_used)/3600:.1f}h)")
        try:
            slot.engine.unload_model()
        except Exception as e:
            log.warning(f"unload_model() failed for '{slot.engine_name}': {e}")
        # Invalidate all sessions pointing at this slot
        for session in sessions.values():
            if session.pool_key == key:
                session.pool_key = None

    return len(stale_keys)
async def _periodic_cleanup():
    """Background task: clean up uploads + expired sessions + idle engine slots every hour.

    Runs forever. A failure in one cleanup pass is logged and the loop carries
    on with the next cycle, so a single unexpected error cannot silently stop
    all future cleanups. Task cancellation still propagates because
    ``asyncio.CancelledError`` is not an ``Exception`` subclass.
    """
    while True:
        await asyncio.sleep(3600)
        try:
            n = _cleanup_old_uploads()
            m = _cleanup_expired_sessions()
            p = _evict_idle_slots()
        except Exception:
            log.exception("Periodic cleanup failed; retrying on the next cycle.")
            continue
        if n or m or p:
            log.info(f"Periodic cleanup: {n} upload(s), {m} session(s), {p} idle engine slot(s).")
# ---------------------------------------------------------------------------
# API key resolution — keys never stored or shared server-side (Phase 3)
# Web UI users MUST provide their own keys via browser localStorage.
# Server env vars (.env) are NOT used by the web UI — they exist only for
# the PyQt GUI and CLI tools which run locally on the admin's machine.
# ---------------------------------------------------------------------------
# Known key slots (for validation only — env vars are NOT consulted)
# Known key slots (for validation only — env vars are NOT consulted)
_KEY_SLOTS = {"openai", "gemini", "claude", "openwebui"}


def _resolve_api_key(slot: str, request_value: str) -> str:
    """Return the API key supplied by the browser, trimmed, or ``""`` if none.

    ``slot`` is accepted for interface symmetry but not consulted. Server
    environment variables are intentionally never used as a fallback: every
    web user has to provide their own key from browser localStorage.
    """
    if not request_value:
        return ""
    # A whitespace-only value strips down to the empty string.
    return request_value.strip()
# ---------------------------------------------------------------------------
# Startup config (web/server_config.yaml) — optional, auto-load an engine
# ---------------------------------------------------------------------------
def _load_startup_config() -> dict:
    """Read the optional ``server_config.yaml`` that sits next to this module.

    Returns:
        The parsed top-level mapping, or ``{}`` when the file is absent, empty,
        unreadable, not valid YAML, or its top-level value is not a mapping.
        Problems are logged as warnings and never raised.
    """
    cfg_path = Path(__file__).parent / "server_config.yaml"
    if not cfg_path.exists():
        return {}
    try:
        import yaml
        # Explicit encoding: do not depend on the platform default locale.
        with open(cfg_path, encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
    except Exception as e:
        log.warning(f"Could not read server_config.yaml: {e}")
        return {}
    if not loaded:
        return {}
    if not isinstance(loaded, dict):
        # A list or scalar document would break the callers' cfg.get(...) lookups.
        log.warning("Ignoring server_config.yaml: top-level content is not a mapping.")
        return {}
    return loaded
@app.on_event("startup")
async def startup_event():
    """Clean old uploads, start periodic cleanup, auto-load engine.

    The auto-load step only runs when ``server_config.yaml`` names a
    ``default_engine``. Every failure during auto-load is logged as a warning;
    it never prevents the server from starting.
    """
    global loaded_engine, loaded_engine_name, loaded_config

    # Clean up uploads left over from previous server runs
    n = _cleanup_old_uploads()
    if n:
        log.info(f"Startup cleanup: removed {n} old upload file(s).")

    # Schedule periodic cleanup (every hour).
    # The event loop keeps only weak references to tasks, so hold a strong one
    # on the application state to keep the task from being garbage-collected.
    app.state.periodic_cleanup_task = asyncio.create_task(_periodic_cleanup())

    # Auto-load default engine from server_config.yaml if present
    cfg = _load_startup_config()
    if not cfg.get("default_engine"):
        return
    engine_name = cfg["default_engine"]
    # An explicit null ("default_config:" with no value) is treated as "no config".
    engine_config = cfg.get("default_config") or {}
    log.info(f"Auto-loading engine '{engine_name}' from server_config.yaml ...")
    try:
        registry = get_global_registry()
        reg_engine = registry.get_engine_by_name(engine_name)
        if not (reg_engine and reg_engine.is_available()):
            log.warning(f"Auto-load: engine '{engine_name}' not found or not available.")
            return

        engine = _create_engine_instance(engine_name)
        if not engine:
            log.warning(f"Auto-load: cannot create instance for '{engine_name}'.")
            return

        # load_model is run in a worker thread so it does not block the event loop.
        ok = await asyncio.to_thread(engine.load_model, engine_config)
        if not ok:
            log.warning(f"Auto-load of '{engine_name}' failed (load_model returned False).")
            return

        pool_key = _make_pool_key(engine_name, engine_config)
        engine_pool[pool_key] = EngineSlot(
            engine=engine, engine_name=engine_name,
            config=engine_config, pool_key=pool_key,
            ref_count=0,  # No session owns it yet
        )
        # Update compat shims
        loaded_engine = engine
        loaded_engine_name = engine_name
        loaded_config = engine_config
        log.info(f"Auto-loaded '{engine_name}' into pool as '{pool_key}'.")
    except Exception as e:
        log.warning(f"Auto-load error: {e}")
# ---------------------------------------------------------------------------
# Config schemas — replaces Qt config widgets for the web UI
# ---------------------------------------------------------------------------
def _get_pylaia_model_options() -> list:
    """Build the select options for the CRNN-CTC (PyLaia-inspired) model field.

    Triggers the lazy segmenter import, asks ``inference_pylaia_native`` to scan
    the project's ``models`` directory, then lists every key of ``PYLAIA_MODELS``
    followed by the ``__custom__`` entry for a manually entered path.
    """
    _import_segmenters()
    from inference_pylaia_native import _scan_pylaia_models

    models_dir = Path(__file__).resolve().parents[1] / "models"
    _scan_pylaia_models(str(models_dir))

    return [
        *({"label": model_name, "value": model_name} for model_name in PYLAIA_MODELS.keys()),
        {"label": "Custom / local path…", "value": "__custom__"},
    ]
def _scan_kraken_models() -> list:
    """Collect the select options for Kraken models.

    Local ``*.mlmodel`` files found anywhere below the project's ``models``
    directory come first (sorted by path, value relative to the project
    directory), followed by the Zenodo presets declared in
    ``engines.kraken_engine.KRAKEN_MODELS``. The preset part is best effort:
    any error while importing or reading the presets is ignored.
    """
    models_root = Path(__file__).resolve().parents[1] / "models"
    project_dir = models_root.parent

    options = []
    if models_root.exists():
        options = [
            {
                "label": f"{model_file.parent.name}/{model_file.name}",
                # e.g. models/kraken_cs/best.mlmodel
                "value": str(model_file.relative_to(project_dir)),
                "source": "local",
            }
            for model_file in sorted(models_root.rglob("*.mlmodel"))
        ]

    # Zenodo presets from kraken_engine (auto-download on load)
    try:
        from engines.kraken_engine import KRAKEN_MODELS
        for preset_id, info in KRAKEN_MODELS.items():
            if info.get("source") != "zenodo":
                continue
            options.append({
                "label": f"{info.get('label', preset_id)} [Zenodo, auto-download]",
                "value": f"__zenodo__{preset_id}",
                "source": "zenodo",
            })
    except Exception:
        pass
    return options
def _scan_trocr_models() -> list:
    """Scan models/ directory for TrOCR checkpoints.

    A directory is considered a TrOCR model if it contains
    preprocessor_config.json (TrOCR/ViT-specific) AND config.json
    with model_type == 'vision-encoder-decoder'.
    This avoids picking up PyLaia/CRNN-CTC directories that also
    contain a config.json with training parameters.

    Returns:
        The built-in entries (custom sentinel plus HuggingFace presets)
        followed by one ``source == "local"`` option per matching directory
        directly under ``PROJECT_ROOT / "models"``, in sorted order.
    """
    models_dir = PROJECT_ROOT / "models"
    options = [
        {"label": "Custom HuggingFace ID or local path…", "value": "__custom__"},
        {"label": "kazars24/trocr-base-handwritten-ru (HuggingFace)",
         "value": "kazars24/trocr-base-handwritten-ru",
         "source": "huggingface"},
        {"label": "microsoft/trocr-base-printed — printed text, base",
         "value": "microsoft/trocr-base-printed",
         "source": "huggingface"},
        {"label": "microsoft/trocr-large-printed — printed text, large",
         "value": "microsoft/trocr-large-printed",
         "source": "huggingface"},
        {"label": "dh-unibe/trocr-kurrent — German Kurrent 19th c. (CER 2.66%)",
         "value": "dh-unibe/trocr-kurrent",
         "source": "huggingface"},
        {"label": "dh-unibe/trocr-kurrent-XVI-XVII — German Kurrent 16th–18th c. (CER 5.42%)",
         "value": "dh-unibe/trocr-kurrent-XVI-XVII",
         "source": "huggingface"},
    ]
    if not models_dir.exists():
        return options

    for d in sorted(models_dir.iterdir()):
        if not d.is_dir():
            continue
        # Require BOTH preprocessor_config.json AND config.json with
        # model_type == 'vision-encoder-decoder'.
        # preprocessor_config.json is ViT/TrOCR-specific (not in PyLaia).
        # config.json model_type disambiguates from Qwen3 adapters that
        # also ship a preprocessor_config but have no config.json.
        if not (d / "preprocessor_config.json").exists():
            continue
        cfg_path = d / "config.json"
        if not cfg_path.exists():
            continue
        try:
            # Context manager: the original json.load(open(...)) never closed the file.
            with open(cfg_path, encoding="utf-8") as f:
                cfg = json.load(f)
            if cfg.get("model_type") != "vision-encoder-decoder":
                continue
        except Exception:
            # Unreadable or malformed config.json: not offered as a model.
            continue
        options.append({
            "label": d.name,
            "value": str(d),
            "source": "local",
        })
    return options
def _vlm_adapter_option(adapter_dir: Path, display_name: str, engine_type: str) -> Optional[dict]:
    """Return the select option for the LoRA adapter in ``adapter_dir``, or ``None``.

    ``None`` means the adapter does not belong to ``engine_type`` or its
    ``adapter_config.json`` could not be read or interpreted.
    """
    try:
        with open(adapter_dir / "adapter_config.json", encoding="utf-8") as f:
            adapter_cfg = json.load(f)
        base = adapter_cfg.get("base_model_name_or_path", "")
        is_qwen = "qwen" in base.lower() or "qwen" in display_name.lower()
        is_churro = "churro" in base.lower() or "churro" in display_name.lower()
    except Exception as e:
        # Best effort: a broken adapter directory must not break the whole listing.
        log.debug(f"Skipping adapter '{adapter_dir}': {e}")
        return None

    # Churro wins over Qwen when both match: such adapters are only listed for "churro".
    if engine_type == "qwen3":
        wanted = is_qwen and not is_churro
    elif engine_type == "churro":
        wanted = is_churro
    else:
        wanted = False
    if not wanted:
        return None

    return {
        "label": f"{display_name} (LoRA → {base})",
        "value": str(adapter_dir),
        "base_model": base,
        "adapter": str(adapter_dir),
    }


def _scan_vlm_models(engine_type: str = "qwen3") -> list:
    """Scan models/ directory for local LoRA adapters usable by a VLM engine.

    A model directory qualifies when it contains ``adapter_config.json`` either
    at its top level or inside a ``final_model`` subdirectory (the top level
    takes precedence; ``final_model`` is then not inspected).

    Args:
        engine_type: ``"qwen3"`` or ``"churro"``; any other value matches nothing.

    Returns:
        Matching options in sorted directory order, always ending with the
        ``__custom__`` sentinel for manual entry.
    """
    models_dir = PROJECT_ROOT / "models"
    options = []
    if models_dir.exists():
        for d in sorted(models_dir.iterdir()):
            if not d.is_dir():
                continue
            if (d / "adapter_config.json").exists():
                # Check for LoRA adapter at top-level
                adapter_dir = d
            else:
                # Check for final_model subdirectory with adapter
                final = d / "final_model"
                if not (final.is_dir() and (final / "adapter_config.json").exists()):
                    continue
                adapter_dir = final
            option = _vlm_adapter_option(adapter_dir, d.name, engine_type)
            if option is not None:
                options.append(option)
    # Always append a "Custom / HuggingFace" sentinel as the last option
    options.append({
        "label": "Custom / HuggingFace model ID...",
        "value": "__custom__",
    })
    return options
# Engine name -> zero-argument callable returning the config-form schema for the web UI.
# The values are lambdas so that model lists are (re)scanned each time a schema is
# requested (get_config_schema calls ENGINE_SCHEMAS[name]()), not once at import time.
# Each schema is {"fields": [...]}; a field dict carries at least "key", "type" and "label".
# "custom_key" names the config key that receives the free-text value when the
# "__custom__" option of a select is chosen.
ENGINE_SCHEMAS = {
    "CRNN-CTC (PyLaia-inspired)": lambda: {
        "fields": [
            {"key": "model_path", "type": "select", "label": "Model",
             "options": _get_pylaia_model_options(),
             "custom_key": "custom_model_path",
             "custom_placeholder": "Absolute path to best_model.pt (e.g. /home/…/models/pylaia_yiddish_20260326/best_model.pt)"},
            {"key": "enable_spaces", "type": "checkbox",
             "label": "Convert <space> tokens", "default": True},
            {"key": "flip_rtl", "type": "checkbox",
             "label": "RTL manuscript (flip line images)", "default": False,
             "hint": "Flip line images horizontally for RTL scripts (Ottoman, Arabic, Hebrew)"},
        ]
    },
    "TrOCR": lambda: {
        "fields": [
            {"key": "model_path", "type": "select", "label": "Model",
             "options": _scan_trocr_models(),
             "custom_key": "custom_model_path",
             "custom_placeholder": "HuggingFace model ID (e.g. microsoft/trocr-base-handwritten) or absolute local path"},
            {"key": "num_beams", "type": "number", "label": "Beam Search",
             "min": 1, "max": 10, "default": 4},
            {"key": "normalize_background", "type": "checkbox",
             "label": "Normalize Background", "default": False},
            {"key": "flip_rtl", "type": "checkbox",
             "label": "RTL manuscript (flip line images)", "default": False,
             "hint": "Flip line images horizontally for RTL scripts (Ottoman, Arabic, Hebrew)"},
        ]
    },
    "Qwen3-VL": lambda: {
        "fields": [
            {"key": "model_preset", "type": "select", "label": "Model",
             "options": _scan_vlm_models("qwen3"),
             "custom_key": "base_model",
             "custom_placeholder": "HuggingFace model ID, e.g. Qwen/Qwen3-VL-8B-Instruct"},
            {"key": "max_image_size", "type": "number", "label": "Max Image Size (px)",
             "min": 512, "max": 4096, "default": 1536},
        ]
    },
    "Churro VLM": lambda: {
        "fields": [
            {"key": "model_preset", "type": "select", "label": "Model",
             "options": _scan_vlm_models("churro"),
             "custom_key": "model_name",
             "custom_placeholder": "HuggingFace model ID, e.g. stanford-oval/churro-3B"},
            {"key": "device", "type": "select", "label": "Device",
             "options": [{"label": "Auto", "value": "auto"},
                         {"label": "GPU 0", "value": "cuda:0"},
                         {"label": "GPU 1", "value": "cuda:1"},
                         {"label": "CPU", "value": "cpu"}]},
            {"key": "max_image_size", "type": "number", "label": "Max Image Size (px)",
             "min": 512, "max": 4096, "default": 2048},
        ]
    },
    "Kraken": lambda: {
        "fields": [
            {"key": "model_path", "type": "select", "label": "Model",
             "options": _scan_kraken_models(),
             "custom_key": "custom_model_path",
             "custom_placeholder": "Absolute path on server, e.g. /home/user/models/my.mlmodel",
             "upload": True},
        ]
    },
    "Commercial APIs": lambda: {
        "fields": [
            {"key": "provider", "type": "select", "label": "Provider",
             "options": [
                 {"label": "OpenAI (GPT-4o, o1, …)", "value": "OpenAI"},
                 {"label": "Google Gemini", "value": "Gemini"},
                 {"label": "Anthropic Claude", "value": "Claude"},
             ]},
            {"key": "model", "type": "select", "label": "Model",
             "dynamic": True,
             "dynamic_hint": "Enter API key, then ↻ to load available models",
             # No static lists — always fetch live from the provider API
             "per_provider_options": {},
             "options": [],
             "custom_key": "custom_model_id",
             "custom_placeholder": "e.g. gpt-4.5, gemini-exp-1206, claude-opus-4"},
            {"key": "api_key", "type": "password", "label": "API Key",
             "default": "", "placeholder": "Paste your API key here"},
            {"key": "temperature", "type": "number", "label": "Temperature",
             "min": 0.0, "max": 2.0, "default": 0.0,
             "placeholder": "0.0 = deterministic (recommended for transcription)"},
            {"key": "max_output_tokens", "type": "number", "label": "Max output tokens (optional)",
             "min": 512, "max": 65536, "default": None,
             "placeholder": "Leave blank = model maximum"},
            {"key": "custom_prompt", "type": "textarea", "label": "Custom Prompt (optional)",
             "default": "",
             "rows": 4,
             "placeholder": "Transcribe all handwritten text in this manuscript image. Preserve the original language (Cyrillic, Latin, etc.) and layout. Output only the transcribed text without any additional commentary.",
             "hint": "Leave blank to use the default prompt shown above"},
            {"key": "thinking_mode", "type": "select", "label": "Thinking Mode (Gemini only)",
             "options": [
                 {"label": "Auto (model decides, no cap)", "value": ""},
                 {"label": "Low (budget: 8k tokens)", "value": "low"},
                 {"label": "High (no cap, max reasoning)", "value": "high"},
             ], "default": ""},
        ]
    },
    "OpenWebUI": lambda: {
        "fields": [
            {"key": "base_url", "type": "text", "label": "Base URL",
             "default": "",
             "placeholder": "https://your-openwebui-instance/api or .../api/v1"},
            {"key": "api_key", "type": "password", "label": "API Key",
             "default": "", "placeholder": "Your OpenWebUI API key"},
            {"key": "model", "type": "select", "label": "Model",
             "dynamic": True,
             "dynamic_hint": "Enter API key & base URL, then ↻ to load available models",
             "options": [{"label": "Custom model ID…", "value": "__custom__"}],
             "default": "__custom__",
             "custom_key": "model_custom",
             "custom_placeholder": "e.g. llama3.1, qwen2.5vl, gemma3, ..."},
            {"key": "temperature", "type": "number", "label": "Temperature",
             "min": 0.0, "max": 2.0, "default": 0.1},
            {"key": "max_tokens", "type": "number", "label": "Max output tokens (optional)",
             "min": 512, "max": 65536, "default": None,
             "placeholder": "Leave blank = model maximum"},
            {"key": "custom_prompt", "type": "textarea", "label": "Custom Prompt (optional)",
             "default": "",
             "rows": 3,
             "placeholder": "Transcribe all handwritten text in this manuscript image. Preserve the original language (Cyrillic, Latin, etc.) and layout. Output only the transcribed text without any additional commentary.",
             "hint": "Leave blank to use the default prompt shown above"},
        ]
    },
    "LightOnOCR": lambda: {
        "fields": [
            {"key": "model_path", "type": "select", "label": "Model",
             # Immediately-invoked lambda: imports lighton_models only when this schema
             # is built and turns LIGHTON_MODELS into options, plus the custom entry.
             "options": (lambda: [
                 {"label": f"{name}{info.get('description','')}", "value": info["id"]}
                 for name, info in __import__('lighton_models', fromlist=['LIGHTON_MODELS']).LIGHTON_MODELS.items()
             ] + [{"label": "Custom HuggingFace ID…", "value": "__custom__"}])(),
             "custom_key": "custom_model_path",
             "custom_placeholder": "HuggingFace model ID, e.g. lightonai/LightOnOCR-2-1B-base"},
            {"key": "max_new_tokens", "type": "number", "label": "Max new tokens",
             "min": 32, "max": 512, "default": 128},
        ]
    },
    "PaddleOCR": lambda: {
        "fields": [
            {"key": "lang", "type": "select", "label": "Language / Script",
             "default": "ch",
             "options": [
                 {"label": "Chinese + English (mixed, recommended default)", "value": "ch"},
                 {"label": "English", "value": "en"},
                 {"label": "German", "value": "german"},
                 {"label": "French", "value": "french"},
                 {"label": "Japanese", "value": "japan"},
                 {"label": "Korean", "value": "korean"},
                 {"label": "Arabic", "value": "arabic"},
                 {"label": "Cyrillic (Russian/Ukrainian/Bulgarian)", "value": "cyrillic"},
                 {"label": "Latin script (generic)", "value": "latin"},
                 {"label": "Custom (enter code below)", "value": "__custom__"},
             ],
             "custom_key": "custom_lang",
             "custom_placeholder": "PaddleOCR lang code, e.g. ru, uk, fr, es, it, pt, …",
             "hint": "One language model per run. 'ch' is bilingual (Chinese+English) and PaddleOCR's strongest model. For mixed-script documents outside this list, run separate passes."},
            {"key": "use_angle_cls", "type": "checkbox",
             "label": "Text-angle classifier (correct 180° rotation)", "default": True},
            {"key": "use_gpu", "type": "checkbox",
             "label": "Use GPU (requires paddlepaddle-gpu)", "default": False},
        ]
    },
}
# ---------------------------------------------------------------------------
# Request/response models
# ---------------------------------------------------------------------------
class EngineLoadRequest(BaseModel):
    # Request body naming an engine and the config to load it with.
    engine_name: str  # engine display name, e.g. a key of _ENGINE_FACTORY
    config: Dict[str, Any] = {}  # engine-specific settings; keys follow the engine's schema in ENGINE_SCHEMAS
class TranscribeRequest(BaseModel):
    # Request body for transcribing one previously uploaded image.
    image_id: str  # presumably a key of the session's image_cache — confirm against the transcribe route
    seg_method: str = "kraken"  # kraken, kraken-blla, hpp
    seg_device: str = "cpu"  # device used for segmentation
    max_columns: int = 6  # blla: max sub-columns per region (iterative splitting)
    split_width_fraction: float = 0.40  # blla: min region width (fraction of page) to trigger sub-split
    use_pagexml: bool = True  # use attached PAGE XML for segmentation when available
    text_direction: str = "horizontal-lr"  # reading order for Kraken: horizontal-lr, horizontal-rl, vertical-lr, vertical-rl
    engine_config_overrides: Dict[str, Any] = {}  # live form values merged into stored config at transcription time
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/")
async def index():
    # Entry page of the web UI, served straight from the static directory.
    index_page = STATIC_DIR / "index.html"
    return FileResponse(str(index_page))
@app.get("/demo")
async def pwa_demo():
    # PWA demo page; lives in the "pwa" subfolder of the static directory.
    demo_page = STATIC_DIR / "pwa" / "demo.html"
    return FileResponse(str(demo_page))
@app.get("/manifest.json")
async def pwa_manifest():
    """Serve the PWA manifest from root so scope / start_url are valid."""
    # FileResponse is already imported at module level; no local alias needed.
    manifest_path = STATIC_DIR / "pwa" / "manifest.json"
    return FileResponse(str(manifest_path), media_type="application/manifest+json")
@app.get("/sw.js")
async def pwa_service_worker():
    """Serve the PWA service worker from root scope so it can control /demo."""
    worker_path = STATIC_DIR / "pwa" / "sw.js"
    response = FileResponse(str(worker_path), media_type="application/javascript")
    # Lets the worker claim the root scope although the script file lives under /static/pwa.
    response.headers["Service-Worker-Allowed"] = "/"
    return response
@app.get("/api/engines")
async def list_engines():
    # One summary dict per registered engine, in registry order.
    def _summarise(engine) -> dict:
        engine_name = engine.get_name()
        available = engine.is_available()
        # The reason is only asked for when the engine is actually unavailable.
        reason = None if available else engine.get_unavailable_reason()
        return {
            "name": engine_name,
            "description": engine.get_description(),
            "available": available,
            "unavailable_reason": reason,
            "requires_line_segmentation": engine.requires_line_segmentation(),
            "has_config_schema": engine_name in ENGINE_SCHEMAS,
        }

    return [_summarise(engine) for engine in get_global_registry().get_all_engines()]
@app.get("/api/engine/{name}/config-schema")
async def get_config_schema(name: str):
    # Engines without a registered schema get an empty form.
    build_schema = ENGINE_SCHEMAS.get(name)
    if build_schema is None:
        return {"fields": []}

    schema = build_schema()
    # Key status: always "missing" from server perspective — browser localStorage
    # is the only key store. The frontend checks localStorage client-side.
    for spec in schema.get("fields", []):
        if spec.get("type") == "password":
            spec["key_status"] = "missing"
    return schema
def _openwebui_model_urls(base_url: str) -> list[str]:
    """Build the ordered list of candidate model-listing URLs for *base_url*.

    The base is stripped of surrounding whitespace and trailing slashes.
    ``<base>/models`` is always tried first; further candidates depend on
    whether the base already ends in ``/api``, ``/api/v1`` or ``/v1``.
    Duplicates are removed while keeping first-seen order. An empty base
    yields an empty list.
    """
    root = base_url.strip().rstrip("/")
    if not root:
        return []
    candidates = [f"{root}/models"]
    if root.endswith("/api"):
        candidates += [f"{root}/v1/models", f"{root[:-4]}/v1/models"]
    elif root.endswith("/api/v1"):
        candidates += [f"{root[:-3]}/models", f"{root}/models"]
    elif root.endswith("/v1"):
        candidates.append(f"{root[:-3]}/api/models")
    else:
        candidates += [
            f"{root}/{suffix}"
            for suffix in ("api/models", "api/v1/models", "v1/models")
        ]
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(candidates))
def _extract_openwebui_model_ids(payload: Any) -> list[str]:
    """Pull model identifiers out of a decoded model-listing response.

    Accepted shapes:
    - a list of strings and/or dicts (dicts are read via ``id``, ``name`` or
      ``model``): returns the sorted, de-duplicated ids;
    - a dict wrapping such a list under ``data`` or ``models``: unwrapped and
      handled as above;
    - any other dict: ids are taken from its dict-valued entries (``id`` or
      ``name``), in iteration order, without sorting or de-duplication.
    Anything else yields an empty list.
    """
    if isinstance(payload, list):
        found = set()
        for entry in payload:
            if isinstance(entry, str):
                found.add(entry)
            elif isinstance(entry, dict):
                ident = entry.get("id") or entry.get("name") or entry.get("model")
                if ident:
                    found.add(str(ident))
        return sorted(found)

    if not isinstance(payload, dict):
        return []

    # Prefer a wrapped list ({"data": [...]} or {"models": [...]}).
    for wrapper_key in ("data", "models"):
        wrapped = payload.get(wrapper_key)
        if isinstance(wrapped, list):
            return _extract_openwebui_model_ids(wrapped)

    # Otherwise treat the dict's values as the model records.
    ids = []
    for entry in payload.values():
        if not isinstance(entry, dict):
            continue
        ident = entry.get("id") or entry.get("name")
        if ident:
            ids.append(str(ident))
    return ids
def _fetch_openwebui_models(base_url: str, api_key: str) -> list[str]:
    """Query an OpenWebUI-style server for its model ids (blocking).

    Every candidate URL from ``_openwebui_model_urls(base_url)`` is tried in
    order; the first response that yields at least one model id wins.

    Args:
        base_url: Server base URL as entered by the user.
        api_key: Key sent both as a Bearer token and as ``x-api-key``.

    Returns:
        The model ids extracted from the first usable response.

    Raises:
        RuntimeError: If no candidate produced any model id. The message
            concatenates the per-URL failure reasons.
    """
    import urllib.error
    import urllib.parse
    import urllib.request

    errors = []
    for url in _openwebui_model_urls(base_url):
        try:
            # base_url comes from the client. urlopen() also handles file:// and
            # ftp:// URLs, so restrict the request to plain HTTP(S).
            if urllib.parse.urlsplit(url).scheme not in ("http", "https"):
                errors.append(f"{url}: unsupported URL scheme (only http/https allowed)")
                continue
            req = urllib.request.Request(
                url,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "x-api-key": api_key,
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                    "User-Agent": "Polyscriptor-HTR-Demo/1.0",
                },
            )
            with urllib.request.urlopen(req, timeout=20) as resp:
                status = resp.status
                content_type = resp.headers.get("Content-Type", "")
                body = resp.read().decode("utf-8", errors="replace")
            try:
                payload = json.loads(body)
            except json.JSONDecodeError:
                # Keep a short sample of the body for the error report.
                sample = body.strip().replace("\n", " ")[:120] or "<empty response>"
                errors.append(f"{url}: HTTP {status}, non-JSON response ({content_type}): {sample}")
                continue
            models = _extract_openwebui_model_ids(payload)
            if models:
                return models
            errors.append(f"{url}: no model ids in response")
        except urllib.error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")[:200]
            errors.append(f"{url}: HTTP {exc.code} {body}")
        except Exception as exc:
            # Best effort: record the failure and move on to the next candidate.
            errors.append(f"{url}: {exc}")
    raise RuntimeError("; ".join(errors) if errors else "No OpenWebUI model endpoint tried")
@app.get("/api/engine/status")
async def engine_status(request: Request):
    """Report which engine (if any) is loaded for the calling session.

    Uses the session's pool slot when it has one; otherwise falls back to the
    module-level compat shims (``loaded_engine`` / ``loaded_config``).

    Returns:
        ``{"loaded": bool, "engine_name": str, "config": dict}``. The config
        is returned without credential entries.
    """
    # load_engine stores the resolved API key inside the config. The fallback
    # path below reads the *global* loaded_config, which is set by whichever
    # session loaded an engine last, so echoing it verbatim could hand one
    # user's key to another client. Strip secrets before responding.
    secret_keys = {"api_key"}

    def _public(config):
        if not isinstance(config, dict):
            return config
        return {k: v for k, v in config.items() if k not in secret_keys}

    session = _get_session(request)
    if session.pool_key and session.pool_key in engine_pool:
        slot = engine_pool[session.pool_key]
        return {
            "loaded": slot.engine.is_model_loaded(),
            "engine_name": slot.engine_name,
            "config": _public(slot.config),
        }
    # Fallback: compat shim for tests / startup
    return {
        "loaded": loaded_engine is not None and loaded_engine.is_model_loaded(),
        "engine_name": loaded_engine_name,
        "config": _public(loaded_config),
    }
@app.get("/api/engine/{name}/models")
async def get_engine_models(
    name: str,
    api_key: str = "",
    provider: str = "openai",
    base_url: str = "",
):
    """
    Fetch available models for engines whose model list is dynamic.

    - OpenWebUI: queries the OpenWebUI model-listing endpoints under *base_url*
    - Commercial APIs: uses the ``fetch_*_models`` helpers from
      ``inference_commercial_api`` for the chosen *provider*

    Always returns a dict with a ``models`` list; failures are reported in an
    additional ``error`` string instead of raising.
    """
    if name == "OpenWebUI":
        resolved = _resolve_api_key("openwebui", api_key)
        if not resolved:
            return {"models": [], "error": "No API key — paste one in the form"}
        effective_url = base_url.strip().rstrip("/")
        if not effective_url:
            return {"models": [], "error": "Enter your OpenWebUI base URL"}
        try:
            models = await asyncio.to_thread(_fetch_openwebui_models, effective_url, resolved)
            return {"models": models}
        except Exception as e:
            return {"models": [], "error": str(e)}

    if name == "Commercial APIs":
        prov = provider.lower()
        resolved = _resolve_api_key(prov, api_key)
        if not resolved:
            return {"models": [], "error": "No API key — paste one in the form"}
        fetcher_names = {
            "openai": "fetch_openai_models",
            "gemini": "fetch_gemini_models",
            "claude": "fetch_claude_models",
        }
        fetcher_name = fetcher_names.get(prov)
        if fetcher_name is None:
            return {"models": [], "error": f"Unknown provider: {provider}"}
        try:
            # Only add the project root once; inserting it on every request
            # would make sys.path grow without bound.
            root = str(PROJECT_ROOT)
            if root not in sys.path:
                sys.path.insert(0, root)
            api_module = importlib.import_module("inference_commercial_api")
            fetch_models = getattr(api_module, fetcher_name)
            models = await asyncio.to_thread(fetch_models, resolved)
            return {"models": models}
        except Exception as e:
            return {"models": [], "error": str(e)}

    return {"models": [], "error": f"Dynamic model listing not supported for '{name}'"}
@app.post("/api/engine/load")
async def load_engine(request: Request, req: EngineLoadRequest):
    """Load an engine for the calling session, or reuse one from the pool.

    Steps:
    1. Look the engine up in the registry (404 if unknown, 400 if unavailable).
    2. Translate the raw form values in ``req.config`` into the config the
       engine expects (custom paths, presets, API keys, ...). Missing required
       values raise HTTP 400.
    3. Release the session's previous pool reference, then either reuse an
       existing slot with the same pool key or load a new model in a worker
       thread and register it.

    Returns:
        ``{"success": True, "load_time_s": float, "engine_name": str,
        "reused": bool}``.

    Raises:
        HTTPException: 404 / 400 as described above, 500 if the engine cannot
            be instantiated or ``load_model`` reports failure.
    """
    global loaded_engine, loaded_engine_name, loaded_config
    session = _get_session(request)
    registry = get_global_registry()
    reg_engine = registry.get_engine_by_name(req.engine_name)
    if not reg_engine:
        raise HTTPException(404, f"Engine '{req.engine_name}' not found")
    if not reg_engine.is_available():
        raise HTTPException(400, f"Engine not available: {reg_engine.get_unavailable_reason()}")

    # --- Config resolution ---
    config = dict(req.config)
    if req.engine_name == "CRNN-CTC (PyLaia-inspired)" and "model_path" in config:
        custom_val = config.pop("custom_model_path", "").strip()
        if config["model_path"] == "__custom__":
            if not custom_val:
                raise HTTPException(400, "Please enter an absolute path to a best_model.pt file")
            config["model_path"] = custom_val
        # else: named preset from PYLAIA_MODELS — engine resolves it
    elif req.engine_name == "Kraken" and "model_path" in config:
        custom_val = config.pop("custom_model_path", "").strip()
        val = config["model_path"]
        if val == "__custom__":
            if not custom_val:
                raise HTTPException(400, "Please enter a path to a local .mlmodel file")
            config["model_path"] = custom_val
        elif val.startswith("__zenodo__"):
            # Zenodo preset: pass preset_id, let engine handle download
            config["preset_id"] = val[len("__zenodo__"):]
            config["model_path"] = None
        # else: relative local path from select (e.g. "models/kraken_cs/best.mlmodel") — use as-is
    elif req.engine_name == "TrOCR" and "model_path" in config:
        custom_val = config.pop("custom_model_path", "").strip()
        if config["model_path"] == "__custom__":
            if not custom_val:
                raise HTTPException(400, "Please enter a HuggingFace model ID or local path")
            config["model_path"] = custom_val
        # A path that exists on disk is a local checkpoint; anything else is
        # treated as a HuggingFace model ID.
        if Path(config["model_path"]).exists():
            config["model_source"] = "local"
        else:
            config["model_source"] = "huggingface"
    elif req.engine_name == "Qwen3-VL" and "model_preset" in config:
        preset_val = config.pop("model_preset")
        custom_val = config.pop("base_model", "").strip()
        if preset_val == "__custom__":
            config["base_model"] = custom_val or "Qwen/Qwen3-VL-8B-Instruct"
            config["adapter"] = None
        else:
            vlm_opts = _scan_vlm_models("qwen3")
            matched = next((o for o in vlm_opts if o["value"] == preset_val), None)
            if matched:
                config["base_model"] = matched.get("base_model", preset_val)
                config["adapter"] = matched.get("adapter")
            else:
                config["base_model"] = preset_val
                config["adapter"] = None
    elif req.engine_name == "Churro VLM" and "model_preset" in config:
        preset_val = config.pop("model_preset")
        custom_val = config.pop("model_name", "").strip()
        if preset_val == "__custom__":
            config["model_name"] = custom_val or "stanford-oval/churro-3B"
            config["adapter_path"] = None
        else:
            vlm_opts = _scan_vlm_models("churro")
            matched = next((o for o in vlm_opts if o["value"] == preset_val), None)
            if matched:
                config["model_name"] = matched.get("base_model", preset_val)
                config["adapter_path"] = matched.get("adapter")
            else:
                config["model_name"] = preset_val
                config["adapter_path"] = None
    elif req.engine_name == "LightOnOCR" and "model_path" in config:
        custom_val = config.pop("custom_model_path", "").strip()
        if config["model_path"] == "__custom__":
            if not custom_val:
                raise HTTPException(400, "Please enter a HuggingFace model ID for LightOnOCR")
            config["model_path"] = custom_val
    elif req.engine_name == "PaddleOCR" and "lang" in config:
        if config["lang"] == "__custom__":
            custom_lang = config.pop("custom_lang", "").strip()
            if not custom_lang:
                raise HTTPException(400, "Please enter a PaddleOCR language code")
            config["lang"] = custom_lang
        else:
            config.pop("custom_lang", None)
    elif req.engine_name == "Commercial APIs":
        if config.get("model") == "__custom__":
            config["model"] = config.pop("model_custom", "").strip() or "gpt-4o"
    elif req.engine_name == "OpenWebUI":
        if config.get("model") == "__custom__":
            custom_model = config.pop("model_custom", "").strip()
            if not custom_model:
                raise HTTPException(400, "Please enter an OpenWebUI model ID")
            config["model"] = custom_model

    # Resolve API keys
    if req.engine_name == "Commercial APIs":
        provider_slot = config.get("provider", "openai").lower()
        raw_key = config.get("api_key", "")
        resolved = _resolve_api_key(provider_slot, raw_key)
        if not resolved:
            raise HTTPException(400, f"No API key for {config.get('provider')}. "
                                     "Paste your API key in the field.")
        config["api_key"] = resolved
    elif req.engine_name == "OpenWebUI":
        base_url = config.get("base_url", "").strip().rstrip("/")
        if not base_url:
            raise HTTPException(400, "No OpenWebUI base URL. "
                                     "Enter your own OpenWebUI API base URL.")
        config["base_url"] = base_url
        raw_key = config.get("api_key", "")
        resolved = _resolve_api_key("openwebui", raw_key)
        if not resolved:
            raise HTTPException(400, "No API key for OpenWebUI. "
                                     "Paste your API key in the field.")
        config["api_key"] = resolved

    # Strip empty custom_prompt for API engines (use engine default)
    if req.engine_name in ("Commercial APIs", "OpenWebUI"):
        if not config.get("custom_prompt", "").strip():
            config["custom_prompt"] = None

    # --- Engine pool logic ---
    pool_key = _make_pool_key(req.engine_name, config)
    async with pool_lock:
        # Release previous engine reference for this session
        if session.pool_key and session.pool_key in engine_pool:
            prev_slot = engine_pool[session.pool_key]
            prev_slot.ref_count = max(0, prev_slot.ref_count - 1)
            if prev_slot.ref_count == 0:
                log.info(f"Immediate eviction (engine switch): '{prev_slot.engine_name}'")
                try:
                    prev_slot.engine.unload_model()
                except Exception as e:
                    log.warning(f"unload_model() failed for '{prev_slot.engine_name}': {e}")
                if session.pool_key in engine_pool:
                    del engine_pool[session.pool_key]
        # The session no longer holds a reference. Forget the old key now so
        # that a failed load below, followed by an unload or a retry, cannot
        # decrement the old slot's ref_count a second time and evict an
        # engine that other sessions still use.
        session.pool_key = None

        # Check if this exact engine+model is already loaded
        if pool_key in engine_pool:
            slot = engine_pool[pool_key]
            slot.ref_count += 1
            slot.last_used = time.time()
            session.pool_key = pool_key
            # Update compat shims
            loaded_engine = slot.engine
            loaded_engine_name = slot.engine_name
            loaded_config = slot.config
            log.info(f"Pool hit: reusing '{pool_key}' (ref_count={slot.ref_count})")
            return {"success": True, "load_time_s": 0.0,
                    "engine_name": req.engine_name, "reused": True}

        # Need new slot — evict if VRAM tight
        # NOTE(review): kept inside pool_lock on the assumption that
        # _maybe_evict expects the caller to hold it — confirm.
        await _maybe_evict(req.engine_name)

    # Load model OUTSIDE pool_lock (blocking I/O)
    engine = _create_engine_instance(req.engine_name)
    if not engine:
        raise HTTPException(500, f"Cannot create engine instance for '{req.engine_name}'")
    start = time.time()
    success = await asyncio.to_thread(engine.load_model, config)
    elapsed = time.time() - start
    if not success:
        raise HTTPException(500, "Failed to load model")

    slot = EngineSlot(
        engine=engine,
        engine_name=req.engine_name,
        config=config,
        pool_key=pool_key,
        ref_count=1,
        last_used=time.time(),
    )
    async with pool_lock:
        # Double-check: another request may have loaded the same key concurrently
        if pool_key in engine_pool:
            # Drop our duplicate; a failing unload must not prevent the
            # session from being attached to the existing slot.
            try:
                engine.unload_model()
            except Exception as e:
                log.warning(f"unload_model() failed for duplicate '{req.engine_name}': {e}")
            slot = engine_pool[pool_key]
            slot.ref_count += 1
            slot.last_used = time.time()
        else:
            engine_pool[pool_key] = slot
        session.pool_key = pool_key
        # Update compat shims
        loaded_engine = slot.engine
        loaded_engine_name = slot.engine_name
        loaded_config = slot.config
    log.info(f"Pool miss: loaded '{pool_key}' in {elapsed:.1f}s (pool size={len(engine_pool)})")
    return {"success": True, "load_time_s": round(elapsed, 2),
            "engine_name": req.engine_name, "reused": False}
@app.get("/api/keys")
async def list_keys():
    """Backwards-compatibility endpoint that always answers with an empty dict.

    API keys live in browser localStorage only, so the server has no key
    information to report.
    """
    no_keys: dict = {}
    return no_keys
@app.post("/api/admin/evict-all")
async def admin_evict_all(request: Request):
    """Force-evict all engine slots from VRAM (localhost admin only).

    Unloads every pooled engine (best effort), empties the pool, detaches all
    sessions from their slots and resets the compat shims.

    Returns:
        ``{"evicted": [pool_key, ...]}``.

    Raises:
        HTTPException: 403 unless the request comes from 127.0.0.1 / ::1.
    """
    global loaded_engine, loaded_engine_name, loaded_config

    # Fail closed: when the peer address is unknown (request.client is None)
    # the caller cannot be shown to be local, so deny access.
    client = request.client
    if client is None or client.host not in ("127.0.0.1", "::1"):
        raise HTTPException(status_code=403, detail="localhost only")

    async with pool_lock:
        evicted = []
        # Iterate over a snapshot because entries are deleted inside the loop.
        for key, slot in list(engine_pool.items()):
            try:
                slot.engine.unload_model()
            except Exception as e:
                log.warning(f"admin evict failed for '{key}': {e}")
            del engine_pool[key]
            evicted.append(key)
        for session in sessions.values():
            session.pool_key = None
        loaded_engine = None
        loaded_engine_name = ""
        loaded_config = {}
    log.info(f"Admin force-evict: cleared {len(evicted)} slot(s): {evicted}")
    return {"evicted": evicted}
@app.post("/api/engine/unload")
async def unload_engine(request: Request):
    """Release the calling session's engine reference.

    The slot's ref_count is decremented; when it reaches zero the model is
    unloaded (best effort) and the slot removed from the pool. The session is
    detached and the compat shims are cleared in every case.
    """
    global loaded_engine, loaded_engine_name, loaded_config
    session = _get_session(request)
    async with pool_lock:
        key = session.pool_key
        if key and key in engine_pool:
            held = engine_pool[key]
            held.ref_count = max(0, held.ref_count - 1)
            if held.ref_count == 0:
                log.info(f"Immediate eviction (explicit unload): '{held.engine_name}'")
                try:
                    held.engine.unload_model()
                except Exception as e:
                    log.warning(f"unload_model() failed for '{held.engine_name}': {e}")
                engine_pool.pop(key, None)
        session.pool_key = None
        # Update compat shims
        loaded_engine, loaded_engine_name, loaded_config = None, "", {}
    return {"success": True}
def _register_image(session: UserSession, pil_image: Image.Image, filename: str, save_path: Path) -> str:
    """Cache *pil_image* under a fresh UUID and return that image_id.

    The same record dict is stored in the session's image cache and in the
    global image cache, so later mutations are visible through both.
    """
    new_id = str(uuid.uuid4())
    record = dict(
        path=save_path,
        xml_path=None,
        pil_image=pil_image,
        width=pil_image.width,
        height=pil_image.height,
        filename=filename,
        lines=None,
    )
    for cache in (session.image_cache, global_image_cache):
        cache[new_id] = record
    return new_id
def _get_image_data(session: UserSession, image_id: str) -> Optional[dict]:
    """Look up an image record, tolerating missing cookies in embedded Space contexts.

    The session's own cache is consulted first. On a miss the global cache is
    tried, and a hit there is copied into the session cache. Returns ``None``
    when the id is unknown to both.
    """
    own_cache = session.image_cache
    if image_id in own_cache:
        return own_cache[image_id]
    shared = global_image_cache.get(image_id)
    if shared is None:
        return None
    own_cache[image_id] = shared
    return shared
@app.post("/api/image/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    max_dim: Optional[int] = Query(default=None, ge=100, description="Resize long edge to this many pixels (mobile upload only)"),
):
    """Accept an image or PDF upload and register it in the session cache.

    PDFs are rendered page by page at 150 dpi in a worker thread; each page
    becomes its own cached image. Regular images are EXIF-transposed,
    converted to RGB and, when *max_dim* is given and exceeded, downscaled
    and re-saved.

    Returns:
        For PDFs: ``{"is_pdf": True, "filename", "num_pages", "pages": [...]}``.
        For images: ``{"image_id", "width", "height", "filename"}``.

    Raises:
        HTTPException: 400 for oversized files (> 200MB), missing PDF
            support, unrenderable PDFs, unsupported file types or
            undecodable images.
    """
    session = _get_session(request)
    filename = file.filename or "upload"
    is_pdf = (
        filename.lower().endswith(".pdf") or
        (file.content_type or "").startswith("application/pdf")
    )
    image_exts = {
        ".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".gif", ".webp"
    }
    is_image = (
        (file.content_type or "").startswith("image/") or
        Path(filename).suffix.lower() in image_exts
    )
    content = await file.read()
    if len(content) > 200 * 1024 * 1024:
        raise HTTPException(400, "File too large (max 200MB)")

    # ── PDF: render each page as a separate image ──────────────────────────
    if is_pdf:
        if not PDF_AVAILABLE:
            raise HTTPException(400, "PDF support requires PyMuPDF. Install with: pip install pymupdf")

        def _render_pdf(data: bytes, stem: str, sess: UserSession) -> list:
            """Render every page to PNG, register it, and describe it."""
            mat = _fitz.Matrix(150 / 72, 150 / 72)  # 72 pt/inch -> 150 dpi
            doc = _fitz.open(stream=data, filetype="pdf")
            try:
                results = []
                for i, page in enumerate(doc):
                    pix = page.get_pixmap(matrix=mat, colorspace=_fitz.csRGB)
                    pil_page = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                    page_filename = f"{stem}_page{i+1:03d}.png"
                    save_path = UPLOAD_DIR / f"{uuid.uuid4()}.png"
                    pil_page.save(save_path)
                    pid = _register_image(sess, pil_page, page_filename, save_path)
                    results.append({
                        "image_id": pid,
                        "filename": page_filename,
                        "width": pil_page.width,
                        "height": pil_page.height,
                        "page": i + 1,
                    })
                return results
            finally:
                # Close the document even when a page fails to render.
                doc.close()

        try:
            # asyncio.to_thread (used elsewhere in this module) replaces the
            # per-request ThreadPoolExecutor + get_event_loop() combination.
            pages_out = await asyncio.to_thread(
                _render_pdf, content, Path(filename).stem, session
            )
        except Exception as e:
            raise HTTPException(400, f"Failed to render PDF: {e}") from e
        return {
            "is_pdf": True,
            "filename": filename,
            "num_pages": len(pages_out),
            "pages": pages_out,
        }

    # ── Regular image ───────────────────────────────────────────────────────
    if not is_image:
        raise HTTPException(400, "File must be an image or PDF")
    ext = Path(filename).suffix or ".jpg"
    save_path = UPLOAD_DIR / f"{uuid.uuid4()}{ext}"
    save_path.write_bytes(content)
    try:
        pil_image = Image.open(save_path)
        pil_image = ImageOps.exif_transpose(pil_image)
        pil_image = pil_image.convert("RGB")
        if max_dim and max(pil_image.width, pil_image.height) > max_dim:
            pil_image.thumbnail((max_dim, max_dim), Image.LANCZOS)
            pil_image.save(save_path)
    except Exception as e:
        # Do not leave an undecodable upload behind on disk.
        save_path.unlink(missing_ok=True)
        raise HTTPException(400, f"Invalid image: {e}")
    image_id = _register_image(session, pil_image, filename, save_path)
    return {
        "image_id": image_id,
        "width": pil_image.width,
        "height": pil_image.height,
        "filename": filename,
    }
@app.post("/api/image/{image_id}/xml")
async def upload_xml(request: Request, image_id: str, file: UploadFile = File(...)):
    """Attach a PAGE XML file to an already-uploaded image.

    The XML is written to ``UPLOAD_DIR/<image_id>.xml`` and its path stored
    on the image record. A segmentation previously derived from PAGE XML is
    discarded so it is recomputed from the new file.

    Raises:
        HTTPException: 404 if the image is unknown, 400 if the XML exceeds
            10MB.
    """
    session = _get_session(request)
    img_data = _get_image_data(session, image_id)
    if img_data is None:
        raise HTTPException(404, "Image not found — upload image first")
    content = await file.read()
    if len(content) > 10 * 1024 * 1024:
        raise HTTPException(400, "XML too large (max 10MB)")
    xml_path = UPLOAD_DIR / f"{image_id}.xml"
    xml_path.write_bytes(content)
    img_data["xml_path"] = xml_path
    # /api/transcribe reuses cached lines whenever the cached source matches
    # the desired one. Lines read from an earlier XML would therefore survive
    # a replacement upload, so drop any PAGE-XML-derived segmentation here.
    if img_data.get("seg_source") == "pagexml":
        img_data["lines"] = None
        for stale_key in ("line_regions", "seg_regions", "seg_source"):
            img_data.pop(stale_key, None)
    return {"success": True, "filename": file.filename}
@app.get("/api/image/{image_id}")
async def get_image(request: Request, image_id: str):
    """Serve the stored file of a cached image; 404 if the id is unknown."""
    record = _get_image_data(_get_session(request), image_id)
    if record is None:
        raise HTTPException(404, "Image not found")
    stored_file = record["path"]
    return FileResponse(str(stored_file))
@app.get("/api/image/{image_id}/info")
async def image_info(request: Request, image_id: str):
    """Return filename, dimensions and PAGE XML presence for a cached image.

    Raises HTTP 404 when the image id is unknown.
    """
    record = _get_image_data(_get_session(request), image_id)
    if record is None:
        raise HTTPException(404, "Image not found")
    info = {"image_id": image_id}
    info.update({key: record[key] for key in ("filename", "width", "height")})
    info["has_xml"] = record["xml_path"] is not None
    return info
async def _run_segmentation(img_data: dict, method: str, device: str = "cpu",
                            max_columns: int = 6,
                            split_width_fraction: float = 0.40,
                            text_direction: str = "horizontal-lr") -> dict:
    """
    Shared segmentation helper. Runs the appropriate segmenter, stores
    results in img_data, and returns a serialisable dict ready for SSE or JSON.

    Segmenter choice: an attached PAGE XML always wins; otherwise *method*
    selects ``kraken-blla``, ``kraken`` (classical) or, for anything else,
    HPP. In the ``hf_space`` demo mode ``kraken-blla`` is downgraded to
    classical Kraken on CPU, and Kraken failures or empty results fall back
    to the built-in HPP segmenter.

    Populates ``img_data["lines"]``, ``["line_regions"]`` (one region index
    per line, used by the transcription loop to tag each line with its
    column), ``["seg_source"]`` and ``["seg_regions"]``.

    Returns:
        ``{"num_lines", "bboxes", "source"}`` plus ``"regions"`` when region
        data is available.
    """
    if DEMO_MODE == "hf_space" and method == "kraken-blla":
        method = "kraken"
        device = "cpu"
    pil_image = img_data["pil_image"]
    xml_path = img_data.get("xml_path")
    if DEMO_MODE == "hf_space" and xml_path is None and method == "hpp":
        return await asyncio.to_thread(_run_demo_hpp_segmentation, img_data)

    _import_segmenters()
    regions: list = []
    lines: list = []
    xml_region_data: list = []  # TextRegion bboxes from PAGE XML (for visualization)

    if xml_path is not None:
        from inference_page import PageXMLSegmenter as _PXSeg
        segmenter = _PXSeg(str(xml_path))
        lines = await asyncio.to_thread(segmenter.segment_lines, pil_image)
        source = "pagexml"
        xml_region_data = getattr(segmenter, 'region_data', []) or []
    elif method == "kraken-blla":
        segmenter = KrakenLineSegmenter(device=device)
        regions, lines = await asyncio.to_thread(
            segmenter.segment_with_regions, pil_image,
            device=device,
            max_columns=max_columns,
            split_width_fraction=split_width_fraction,
            text_direction=text_direction,
        )
        source = "kraken-blla"
    elif method == "kraken":
        try:
            segmenter = KrakenLineSegmenter()
            # Use column-aware segmentation so multi-column pages read correctly
            regions, lines = await asyncio.to_thread(
                segmenter.segment_classical_with_regions, pil_image,
                max_columns=max_columns,
            )
            source = "kraken"
        except Exception as exc:
            if DEMO_MODE == "hf_space":
                log.warning("Kraken segmentation failed in HF Space; falling back to HPP: %s", exc)
                return await asyncio.to_thread(_run_demo_hpp_segmentation, img_data, "hpp-fallback")
            raise
    else:  # hpp
        segmenter = LineSegmenter()
        lines = await asyncio.to_thread(segmenter.segment_lines, pil_image)
        source = "hpp"

    if DEMO_MODE == "hf_space" and method == "kraken" and not lines:
        log.warning("Kraken returned no lines in HF Space; falling back to HPP")
        return await asyncio.to_thread(_run_demo_hpp_segmentation, img_data, "hpp-fallback")

    # Build per-line region index (used by transcription loop for column view)
    line_regions: list[int] = [
        ri for ri, r in enumerate(regions) for _ in r.line_ids
    ]
    # Keep the index list exactly as long as `lines`. Consumers zip the two
    # together (region deletion) or default a missing entry to region 0
    # (transcription), so pad with 0 and drop any surplus entries.
    if len(line_regions) < len(lines):
        line_regions.extend([0] * (len(lines) - len(line_regions)))
    elif len(line_regions) > len(lines):
        del line_regions[len(lines):]

    img_data["lines"] = lines
    img_data["line_regions"] = line_regions
    img_data["seg_source"] = source
    # PAGE XML provides region bboxes directly; Kraken/blla provide SegRegion objects
    if xml_region_data:
        img_data["seg_regions"] = xml_region_data
    elif regions:
        img_data["seg_regions"] = [
            {"id": r.id, "bbox": list(r.bbox), "num_lines": len(r.line_ids)}
            for r in regions
        ]
    else:
        img_data["seg_regions"] = []

    result: dict = {
        "num_lines": len(lines),
        "bboxes": [list(l.bbox) for l in lines],
        "source": source,
    }
    if img_data["seg_regions"]:
        result["regions"] = img_data["seg_regions"]
    return result
def _run_demo_hpp_segmentation(img_data: dict, source: str = "hpp") -> dict:
    """Small dependency-light line segmenter for the hosted CPU demo fallback.

    Works on the horizontal projection of "ink" pixels: rows whose smoothed
    ink density exceeds a threshold are grouped into bands, bands shorter
    than a minimum height are dropped, and at most the first 100 bands become
    lines. Each line is padded vertically and cropped horizontally to its
    inked columns (plus a small margin).

    Stores ``lines``, ``line_regions`` (all zero), ``seg_source`` and an empty
    ``seg_regions`` on *img_data* and returns
    ``{"num_lines", "bboxes", "source"}``.
    """
    page = img_data["pil_image"]
    gray = np.array(page.convert("L"))
    found: list = []

    if gray.size != 0:
        # Pixels darker than the (clamped) 42nd percentile count as ink.
        cutoff = min(220, max(90, float(np.percentile(gray, 42))))
        ink_mask = gray < cutoff
        # 9-row moving average of the per-row ink fraction.
        box_filter = np.ones(9, dtype=np.float32) / 9.0
        profile = np.convolve(ink_mask.mean(axis=1), box_filter, mode="same")
        on_level = max(0.01, float(profile.max()) * 0.13)
        shortest = max(10, int(page.height * 0.008))
        last_row = len(profile) - 1

        # Collect (first_row, last_row) runs of consecutive active rows.
        runs = []
        run_start = None
        for row, level in enumerate(profile):
            if run_start is None:
                if level > on_level:
                    run_start = row
                continue
            if level <= on_level or row == last_row:
                run_end = row if row == last_row else row - 1
                if run_end - run_start + 1 >= shortest:
                    runs.append((run_start, run_end))
                run_start = None

        for y1, y2 in runs[:100]:
            margin = max(3, int((y2 - y1 + 1) * 0.25))
            top = max(0, y1 - margin)
            bottom = min(page.height, y2 + margin + 1)
            inked_cols = np.where(ink_mask[top:bottom, :].any(axis=0))[0]
            if inked_cols.size:
                left = max(0, int(inked_cols[0]) - 8)
                right = min(page.width, int(inked_cols[-1]) + 9)
            else:
                left, right = 0, page.width
            box = (left, top, right, bottom)
            found.append(SimpleNamespace(image=page.crop(box), bbox=box, coords=None))

    img_data["lines"] = found
    img_data["line_regions"] = [0] * len(found)
    img_data["seg_source"] = source
    img_data["seg_regions"] = []
    return {
        "num_lines": len(found),
        "bboxes": [list(entry.bbox) for entry in found],
        "source": source,
    }
@app.delete("/api/image/{image_id}/region/{region_index}")
async def delete_region(request: Request, image_id: str, region_index: int):
    """
    Drop one detected region, together with its lines, from the cached
    segmentation of an image.

    Region indices of the remaining lines are shifted down so they stay
    contiguous. The response has the same shape as /segment so the client
    can redraw the canvas.

    Raises HTTP 404 for an unknown image and HTTP 400 when there is no
    segmentation or the index is out of range.
    """
    img_data = _get_image_data(_get_session(request), image_id)
    if img_data is None:
        raise HTTPException(404, "Image not found")

    seg_regions = img_data.get("seg_regions") or []
    if not seg_regions:
        raise HTTPException(400, "No segmentation data — run Segment first")
    if not (0 <= region_index < len(seg_regions)):
        raise HTTPException(400, f"Region index out of range (0–{len(seg_regions)-1})")

    lines = img_data.get("lines") or []
    line_regions = img_data.get("line_regions") or ([0] * len(lines))

    # Keep every (line, region) pair outside the deleted region ...
    survivors = [
        (line, region) for line, region in zip(lines, line_regions)
        if region != region_index
    ]
    new_lines = [line for line, _ in survivors]
    # ... and close the gap in the region numbering.
    new_line_regions = [
        region if region < region_index else region - 1
        for _, region in survivors
    ]
    new_regions = [
        region for pos, region in enumerate(seg_regions) if pos != region_index
    ]

    img_data["lines"] = new_lines
    img_data["line_regions"] = new_line_regions
    img_data["seg_regions"] = new_regions

    result: dict = {
        "num_lines": len(new_lines),
        "bboxes": [list(line.bbox) for line in new_lines],
        "source": img_data.get("seg_source", "modified"),
    }
    if new_regions:
        result["regions"] = new_regions
    return result
@app.get("/api/image/{image_id}/segment")
async def segment_image(
    request: Request,
    image_id: str,
    method: str = "kraken",
    device: str = "cpu",
    max_columns: int = 6,
    split_width_fraction: float = 0.40,
    text_direction: str = "horizontal-lr",
):
    """
    Segment a cached image without transcribing it and return the line
    bounding boxes as JSON, e.g. to preview the line layout beforehand.

    Raises HTTP 404 for an unknown image and HTTP 500 when segmentation fails.
    """
    img_data = _get_image_data(_get_session(request), image_id)
    if img_data is None:
        raise HTTPException(404, "Image not found — upload first")
    try:
        layout = await _run_segmentation(
            img_data, method, device,
            max_columns, split_width_fraction, text_direction,
        )
    except Exception as e:
        raise HTTPException(500, f"Segmentation failed: {e}")
    else:
        return layout
@app.post("/api/transcribe")
async def transcribe(request: Request, req: TranscribeRequest):
    """Transcribe one cached image and stream progress as Server-Sent Events.

    The engine comes from the session's pool slot, or from the module-level
    compat shims when the session has none. ``req.engine_config_overrides``
    is merged into a copy of the stored config, except for keys that are
    fixed at load time.

    Events emitted by the stream: ``status``, ``segmentation``, ``progress``
    (one per line), ``cancelled``, ``complete`` and ``error``. Completed
    results are stored on the image record under ``"results"`` for export.

    Raises:
        HTTPException: 400 if no engine is loaded, 404 if the image id is
            unknown. Errors during streaming are reported as ``error`` events.
    """
    session = _get_session(request)
    # Resolve engine from session's pool slot
    if not session.pool_key or session.pool_key not in engine_pool:
        # Fallback: check compat shims (e.g. auto-loaded engine, no session yet)
        if not loaded_engine or not loaded_engine.is_model_loaded():
            raise HTTPException(400, "No engine loaded")
    slot = engine_pool.get(session.pool_key) if session.pool_key else None
    # Build effective engine/config references
    eff_engine = slot.engine if slot else loaded_engine
    _base_config = slot.config if slot else loaded_config
    # Merge live form overrides into a copy of the stored config so changes to
    # runtime-only fields (custom_prompt, thinking_mode, temperature, …) take
    # effect without requiring a model reload. Never overwrite security-sensitive
    # keys that were set during load (api_key, provider, model, model_path, …).
    # "base_url" and "adapter_path" are also written by load_engine: without
    # them here a per-request override could point a stored API key at a
    # different host or swap the adapter without a reload.
    _RELOAD_ONLY_KEYS = {"api_key", "provider", "model", "model_path", "model_source",
                         "base_model", "adapter", "adapter_path", "model_name",
                         "preset_id", "lang", "use_gpu", "venv_path", "base_url"}
    if req.engine_config_overrides:
        eff_config = dict(_base_config)
        for k, v in req.engine_config_overrides.items():
            if k not in _RELOAD_ONLY_KEYS:
                eff_config[k] = v
    else:
        eff_config = _base_config
    eff_engine_name = slot.engine_name if slot else loaded_engine_name
    if not eff_engine or not eff_engine.is_model_loaded():
        raise HTTPException(400, "No engine loaded")

    img_data = _get_image_data(session, req.image_id)
    if img_data is None:
        raise HTTPException(404, "Image not found — upload first")
    pil_image = img_data["pil_image"]

    # Per-request cancel event (replaces global cancel_event)
    request_id = str(uuid.uuid4())
    cancel_evt = asyncio.Event()
    session.cancel_events[request_id] = cancel_evt

    async def event_stream():
        _import_segmenters()
        try:
            # --- Segmentation ---
            xml_path = img_data.get("xml_path") if req.use_pagexml else None
            if not eff_engine.requires_line_segmentation() and not xml_path:
                # Page-level engine with no PAGE XML — send whole page as single line
                from inference_page import LineSegment
                lines = [LineSegment(
                    image=pil_image,
                    bbox=(0, 0, pil_image.width, pil_image.height),
                    coords=None,
                )]
                img_data["lines"] = lines
                img_data["line_regions"] = [0]
                img_data["seg_source"] = "page"
                img_data["seg_regions"] = []
                yield _sse("segmentation", {
                    "num_lines": 1,
                    "bboxes": [[0, 0, pil_image.width, pil_image.height]],
                    "source": "page",
                })
            else:
                # Reuse cached segmentation if method matches (e.g. user clicked Segment first)
                cached_lines = img_data.get("lines")
                cached_source = img_data.get("seg_source")
                desired_source = "pagexml" if (xml_path and req.use_pagexml) else req.seg_method
                if cached_lines and cached_source == desired_source:
                    lines = cached_lines
                    yield _sse("status", {"message": "Using cached segmentation..."})
                    seg_event: dict = {
                        "num_lines": len(lines),
                        "bboxes": [list(l.bbox) for l in lines],
                        "source": cached_source,
                    }
                    if img_data.get("seg_regions"):
                        seg_event["regions"] = img_data["seg_regions"]
                    yield _sse("segmentation", seg_event)
                elif xml_path is not None:
                    yield _sse("status", {"message": "Reading line layout from PAGE XML..."})
                    seg_result = await _run_segmentation(img_data, "pagexml",
                                                         req.seg_device, req.max_columns,
                                                         req.split_width_fraction,
                                                         req.text_direction)
                    lines = img_data["lines"]
                    yield _sse("segmentation", seg_result)
                else:
                    # NOTE(review): _run_segmentation reads img_data["xml_path"]
                    # itself, so an attached PAGE XML is still used here even
                    # when req.use_pagexml is false — confirm this is intended.
                    yield _sse("status", {"message": f"Segmenting with {req.seg_method}..."})
                    seg_result = await _run_segmentation(img_data, req.seg_method,
                                                         req.seg_device, req.max_columns,
                                                         req.split_width_fraction,
                                                         req.text_direction)
                    lines = img_data["lines"]
                    yield _sse("segmentation", seg_result)

            # --- Transcription ---
            results = []
            token_usage: Dict[str, Any] = {}
            start_time = time.time()
            line_regions = img_data.get("line_regions") or ([0] * len(lines))
            for i, line in enumerate(lines):
                # Check for cancellation before each line
                if cancel_evt.is_set():
                    yield _sse("cancelled", {})
                    return
                line_img = line.image if line.image is not None else pil_image.crop(line.bbox)
                img_array = np.array(line_img.convert("RGB"))
                # Use slot lock to serialize access to this engine instance
                if slot:
                    async with slot.lock:
                        slot.last_used = time.time()
                        result = await asyncio.to_thread(
                            eff_engine.transcribe_line, img_array, eff_config
                        )
                else:
                    result = await asyncio.to_thread(
                        eff_engine.transcribe_line, img_array, eff_config
                    )
                text = str(result.text) if hasattr(result, "text") else str(result)
                confidence = None
                if hasattr(result, "confidence") and result.confidence is not None:
                    confidence = float(result.confidence)
                    # Values above 1 are taken to be percentages.
                    if confidence > 1:
                        confidence = confidence / 100.0
                # Accumulate token usage and extract thinking text from API engines (e.g. Gemini)
                thinking_text = None
                if hasattr(result, "metadata") and isinstance(result.metadata, dict):
                    tu = result.metadata.get("token_usage")
                    if tu:
                        for k, v in tu.items():
                            if v is not None:
                                token_usage[k] = token_usage.get(k, 0) + v
                    thinking_text = result.metadata.get("thinking_text")
                line_data = {
                    "index": i,
                    "text": text,
                    "confidence": confidence,
                    "bbox": list(line.bbox),
                    "region": line_regions[i] if i < len(line_regions) else 0,
                }
                if thinking_text:
                    line_data["thinking_text"] = thinking_text
                results.append(line_data)
                progress_data: Dict[str, Any] = {
                    "current": i + 1,
                    "total": len(lines),
                    "line": line_data,
                }
                if token_usage:
                    # Send a snapshot; the running totals keep changing.
                    progress_data["token_usage"] = dict(token_usage)
                yield _sse("progress", progress_data)
                # Check for cancellation after each line's progress event
                if cancel_evt.is_set():
                    yield _sse("cancelled", {})
                    return

            # Store completed results in session image_cache for export
            img_data["results"] = results
            elapsed = time.time() - start_time
            complete_data: Dict[str, Any] = {
                "lines": results,
                "total_time_s": round(elapsed, 2),
                "engine": eff_engine_name,
            }
            if token_usage:
                complete_data["token_usage"] = token_usage
            yield _sse("complete", complete_data)
        except Exception as e:
            log.exception("Transcription error")
            yield _sse("error", {"message": str(e)})
        finally:
            # Clean up this request's cancel event
            session.cancel_events.pop(request_id, None)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # Disable nginx buffering if behind proxy
        },
    )
@app.post("/api/transcribe/cancel")
async def cancel_transcription(request: Request):
    """Ask every running transcription of the calling session to stop.

    Sets all cancel events currently registered on the session.
    """
    pending = _get_session(request).cancel_events
    for cancel_flag in pending.values():
        cancel_flag.set()
    return {"success": True}
@app.post("/api/image/{image_id}/export-xml")
async def export_xml(request: Request, image_id: str):
    """Export transcription results for image_id as a PAGE XML download.

    The XML body and the file stem come from ``_build_xml_bytes``, which
    raises ``HTTPException`` when the image is unknown or has no results.

    The download name is derived from the uploaded file name, so it may
    contain characters that are not valid in a plain quoted ``filename``
    parameter (non-latin-1 letters, quotes, spaces). Such names are sent
    using the percent-encoded ``filename*`` form instead; plain names keep
    the original ``filename="..."`` header unchanged.
    """
    from urllib.parse import quote

    session = _get_session(request)
    pretty, stem = _build_xml_bytes(session, image_id)

    download_name = f"{stem}.xml"
    encoded_name = quote(download_name)
    if encoded_name == download_name:
        # Nothing needed escaping: keep the simple quoted form.
        disposition = f'attachment; filename="{download_name}"'
    else:
        # Header values must be latin-1 encodable; percent-encoding keeps the
        # header pure ASCII while still carrying the UTF-8 name.
        disposition = f"attachment; filename*=UTF-8''{encoded_name}"

    return Response(
        content=pretty,
        media_type="application/xml",
        headers={"Content-Disposition": disposition},
    )
def _build_xml_bytes(session: UserSession, image_id: str) -> tuple[bytes, str]:
    """Serialize the cached transcription of ``image_id`` as PAGE XML.

    Returns a ``(pretty_printed_xml_bytes, file_stem)`` pair. Raises
    ``HTTPException`` with status 404 when the image is not cached for this
    session and 400 when it has no stored results.

    All lines are written into a single ``TextRegion`` (``region_1``) whose
    coordinates are the bounding rectangle of every line's bbox.
    """
    import xml.etree.ElementTree as ET
    from xml.dom import minidom
    from page_xml_exporter import PageXMLExporter

    cached = _get_image_data(session, image_id)
    if cached is None:
        raise HTTPException(404, f"Image {image_id} not found")
    line_results = cached.get("results")
    if not line_results:
        raise HTTPException(400, f"No results for {image_id}")

    filename = cached.get("filename", cached["path"].name)
    width = cached["width"]
    height = cached["height"]

    class _SegProxy:
        """Minimal stand-in exposing the attributes handed to the exporter."""

        __slots__ = ("bbox", "coords", "text", "confidence")

        def __init__(self, r):
            raw_box = r.get("bbox")
            # A result without a bbox falls back to the whole page rectangle.
            self.bbox = tuple(raw_box) if raw_box else (0, 0, width, height)
            self.coords = None
            self.text = r.get("text", "")
            self.confidence = r.get("confidence")

    segments = [_SegProxy(entry) for entry in line_results]

    exporter = PageXMLExporter(str(filename), width, height)
    root, page = exporter._make_root("Polyscriptor Web UI", None)

    # Reading order: one ordered group that references the single region.
    reading_order = ET.SubElement(page, 'ReadingOrder')
    ordered_group = ET.SubElement(
        reading_order,
        'OrderedGroup',
        id='ro_1',
        caption='Regions reading order',
    )
    ET.SubElement(ordered_group, 'RegionRefIndexed', index='0', regionRef='region_1')

    text_region = ET.SubElement(
        page,
        'TextRegion',
        id='region_1',
        type='paragraph',
        custom='readingOrder {index:0;}',
    )

    if segments:
        left = min(seg.bbox[0] for seg in segments)
        top = min(seg.bbox[1] for seg in segments)
        right = max(seg.bbox[2] for seg in segments)
        bottom = max(seg.bbox[3] for seg in segments)
        outline = f'{left},{top} {right},{top} {right},{bottom} {left},{bottom}'
        ET.SubElement(text_region, 'Coords', {'points': outline})

    for position, seg in enumerate(segments):
        exporter._add_text_line(
            text_region, f'line_{position + 1}', seg, seg.text, position
        )

    raw_xml = ET.tostring(root, encoding='utf-8', method='xml')
    document = minidom.parseString(raw_xml)
    pretty = document.toprettyxml(indent='  ', encoding='utf-8')
    return pretty, Path(filename).stem
def _build_thinking_bytes(session: UserSession, image_id: str) -> tuple[bytes, str]:
    """Collect the per-line "thinking" text of a cached image as UTF-8 bytes.

    Returns ``(thinking_bytes, stem)``. Lines without thinking text are
    omitted. When the image has more than one result line, each block is
    prefixed with a ``=== Line N ===`` header (N is the 1-based line number);
    a single-line result is emitted without a header.

    Raises ``HTTPException``: 404 if the image is unknown or no line carries
    thinking text, 400 if the image has no results at all.
    """
    cached = _get_image_data(session, image_id)
    if cached is None:
        raise HTTPException(404, f"Image {image_id} not found")
    line_results = cached.get("results")
    if not line_results:
        raise HTTPException(400, f"No results for {image_id}")

    filename = cached.get("filename", cached["path"].name)
    stem = Path(filename).stem

    with_headers = len(line_results) > 1
    blocks = []
    for line_no, entry in enumerate(line_results, start=1):
        thought = entry.get("thinking_text", "")
        if not thought:
            continue
        blocks.append(f"=== Line {line_no} ===\n{thought}" if with_headers else thought)

    if not blocks:
        raise HTTPException(404, f"No thinking text for {image_id}")
    return "\n\n".join(blocks).encode("utf-8"), stem
def _build_txt_bytes(session: UserSession, image_id: str) -> tuple[bytes, str]:
    """Render the cached transcription of ``image_id`` as plain UTF-8 text.

    Returns ``(txt_bytes, stem)`` with one result line per text line, joined
    by ``\\n``; a result lacking a ``text`` key contributes an empty line.

    Raises ``HTTPException``: 404 if the image is unknown, 400 if it has no
    stored results.
    """
    cached = _get_image_data(session, image_id)
    if cached is None:
        raise HTTPException(404, f"Image {image_id} not found")
    line_results = cached.get("results")
    if not line_results:
        raise HTTPException(400, f"No results for {image_id}")

    filename = cached.get("filename", cached["path"].name)
    line_texts = [entry.get("text", "") for entry in line_results]
    payload = "\n".join(line_texts).encode("utf-8")
    return payload, Path(filename).stem
# Request body shared by the three /api/batch/export-* endpoints
# (thinking text, plain text and PAGE XML), despite the XML-specific name.
class BatchXMLRequest(BaseModel):
    # IDs of the cached images to include in the archive, in request order.
    image_ids: list[str]
@app.post("/api/batch/export-thinking")
async def batch_export_thinking(request: Request, req: BatchXMLRequest):
    """Return a ZIP archive with one thinking-text file per requested image.

    Images for which ``_build_thinking_bytes`` raises ``HTTPException``
    (unknown image, no results, or no thinking text) are skipped, so the
    archive may be empty.

    Member names are ``<stem>_thinking.txt``. When two images share a stem
    (or an image_id is repeated) later members get a numeric suffix
    (``<stem>_2_thinking.txt`` ...) so no entry overwrites another on
    extraction.
    """
    import io
    import zipfile

    session = _get_session(request)
    buf = io.BytesIO()
    used_names: set[str] = set()
    with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
        for image_id in req.image_ids:
            try:
                thinking_bytes, stem = _build_thinking_bytes(session, image_id)
            except HTTPException:
                continue  # skip pages without thinking

            arcname = f"{stem}_thinking.txt"
            counter = 2
            while arcname in used_names:
                arcname = f"{stem}_{counter}_thinking.txt"
                counter += 1
            used_names.add(arcname)
            zf.writestr(arcname, thinking_bytes)

    return Response(
        content=buf.getvalue(),
        media_type="application/zip",
        headers={"Content-Disposition": 'attachment; filename="batch_thinking.zip"'},
    )
@app.post("/api/batch/export-txt")
async def batch_export_txt(request: Request, req: BatchXMLRequest):
    """Return a ZIP archive containing one plain-text file per requested image.

    Images for which ``_build_txt_bytes`` raises ``HTTPException`` (unknown
    image or no results) are skipped, so the archive may be empty.

    Member names are ``<stem>.txt``. When two images share a stem (or an
    image_id is repeated) later members get a numeric suffix
    (``<stem>_2.txt`` ...) so no entry overwrites another on extraction.
    """
    import io
    import zipfile

    session = _get_session(request)
    buf = io.BytesIO()
    used_names: set[str] = set()
    with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
        for image_id in req.image_ids:
            try:
                txt_bytes, stem = _build_txt_bytes(session, image_id)
            except HTTPException:
                continue  # skip images without results

            arcname = f"{stem}.txt"
            counter = 2
            while arcname in used_names:
                arcname = f"{stem}_{counter}.txt"
                counter += 1
            used_names.add(arcname)
            zf.writestr(arcname, txt_bytes)

    return Response(
        content=buf.getvalue(),
        media_type="application/zip",
        headers={"Content-Disposition": 'attachment; filename="batch_export_txt.zip"'},
    )
@app.post("/api/batch/export-xml")
async def batch_export_xml(request: Request, req: BatchXMLRequest):
    """Return a ZIP archive containing one PAGE XML file per requested image.

    Images for which ``_build_xml_bytes`` raises ``HTTPException`` (unknown
    image or no results) are skipped, so the archive may be empty. Any other
    exception from the XML builder propagates to the caller.

    Member names are ``<stem>.xml``. When two images share a stem (or an
    image_id is repeated) later members get a numeric suffix
    (``<stem>_2.xml`` ...) so no entry overwrites another on extraction.
    """
    import io
    import zipfile

    session = _get_session(request)
    buf = io.BytesIO()
    used_names: set[str] = set()
    with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
        for image_id in req.image_ids:
            try:
                xml_bytes, stem = _build_xml_bytes(session, image_id)
            except HTTPException:
                continue  # skip images without results

            arcname = f"{stem}.xml"
            counter = 2
            while arcname in used_names:
                arcname = f"{stem}_{counter}.xml"
                counter += 1
            used_names.add(arcname)
            zf.writestr(arcname, xml_bytes)

    return Response(
        content=buf.getvalue(),
        media_type="application/zip",
        headers={"Content-Disposition": 'attachment; filename="batch_export.zip"'},
    )
@app.get("/api/session")
async def session_info(request: Request):
    """Describe the caller's session; intended as a debugging aid.

    Only the first eight characters of the session id are exposed. The
    payload also reports how many images are cached, how many cancel events
    (i.e. registered transcriptions) exist, the session's engine pool key,
    its timestamps, and the number of sessions known to the server.
    """
    current = _get_session(request)
    masked_id = current.session_id[:8] + "..."
    summary = dict(
        session_id=masked_id,
        images=len(current.image_cache),
        active_transcriptions=len(current.cancel_events),
        pool_key=current.pool_key,
        created_at=current.created_at,
        last_active=current.last_active,
        total_sessions=len(sessions),
    )
    return summary
@app.get("/api/engine/pool")
async def pool_status():
    """Report the state of every slot in the engine pool (admin/debug).

    For each slot the response lists its pool key, engine name, reference
    count, whether the engine reports its model as loaded, the last-used
    timestamp and the seconds elapsed since then (rounded to a whole number).
    """
    slots = [
        {
            "pool_key": pool_key,
            "engine_name": entry.engine_name,
            "ref_count": entry.ref_count,
            "loaded": entry.engine.is_model_loaded(),
            "last_used": entry.last_used,
            "age_s": round(time.time() - entry.last_used, 0),
        }
        for pool_key, entry in engine_pool.items()
    ]
    return {
        "pool_size": len(engine_pool),
        "slots": slots,
        "total_sessions": len(sessions),
    }
@app.get("/api/kraken/presets")
async def kraken_presets():
    """List the Kraken model presets declared in ``KRAKEN_MODELS``.

    Responds with ``{"presets": []}`` when the Kraken engine module cannot be
    imported. Each preset carries its id, a label (the preset's description,
    falling back to the id), and its language and source (empty string when
    not declared).
    """
    try:
        from engines.kraken_engine import KRAKEN_MODELS
    except ImportError:
        return {"presets": []}

    return {
        "presets": [
            {
                "id": preset_id,
                "label": meta.get("description", preset_id),
                "language": meta.get("language", ""),
                "source": meta.get("source", ""),
            }
            for preset_id, meta in KRAKEN_MODELS.items()
        ]
    }
@app.post("/api/models/upload")
async def upload_model(file: UploadFile = File(...)):
    """Upload a Kraken .mlmodel file to the server's models/kraken_uploads/ directory.

    The upload is read in bounded chunks and rejected as soon as it exceeds
    the 500 MB limit, so an oversized upload is never loaded into memory in
    full before being refused.

    Returns the stored path (relative to the project root), the sanitized
    file name, the size in bytes and the refreshed Kraken model option list.

    Raises ``HTTPException`` (400) when the file name does not end in
    ``.mlmodel`` or the upload is larger than 500 MB.

    NOTE(review): an existing file with the same sanitized name is
    overwritten, and no session check is visible in this handler — confirm
    both are intended.
    """
    filename = file.filename or "model.mlmodel"
    if not filename.lower().endswith(".mlmodel"):
        raise HTTPException(400, "Only .mlmodel files are accepted")

    max_bytes = 500 * 1024 * 1024
    chunk_bytes = 1024 * 1024
    chunks: list[bytes] = []
    size = 0
    while True:
        chunk = await file.read(chunk_bytes)
        if not chunk:
            break
        size += len(chunk)
        if size > max_bytes:
            # Stop reading immediately instead of buffering the remainder.
            raise HTTPException(400, "File too large (max 500 MB)")
        chunks.append(chunk)

    upload_dir = PROJECT_ROOT / "models" / "kraken_uploads"
    upload_dir.mkdir(parents=True, exist_ok=True)

    # Sanitize filename — drop any directory part, keep only safe characters
    safe_name = Path(filename).name
    safe_name = "".join(c for c in safe_name if c.isalnum() or c in "._- ")
    safe_name = safe_name.strip() or "uploaded.mlmodel"

    dest = upload_dir / safe_name
    with dest.open("wb") as out:
        # Write chunk by chunk; avoids building a second full-size copy.
        out.writelines(chunks)
    log.info(f"Uploaded Kraken model: {dest} ({size} bytes)")

    rel_path = str(dest.relative_to(PROJECT_ROOT))  # e.g. models/kraken_uploads/foo.mlmodel
    return {
        "path": rel_path,
        "filename": safe_name,
        "size": size,
        "options": _scan_kraken_models(),  # refreshed list for frontend to repopulate select
    }
@app.get("/api/gpu")
async def gpu_status():
    """Report CUDA availability and per-GPU memory / utilization figures.

    Returns ``{"available": False, "gpus": []}`` when torch reports no CUDA
    device or when anything goes wrong (including torch not being
    importable). Otherwise each GPU entry holds its index, name, and total /
    used / free memory in MB (bytes divided by 1e6). Utilization percentages
    are added only when pynvml is usable.
    """
    try:
        import torch
        if not torch.cuda.is_available():
            return {"available": False, "gpus": []}

        # pynvml (nvidia-ml-py) for utilization %; graceful fallback if missing
        nvml_utils: dict[int, dict] = {}
        try:
            import pynvml
            pynvml.nvmlInit()
            try:
                for _i in range(pynvml.nvmlDeviceGetCount()):
                    h = pynvml.nvmlDeviceGetHandleByIndex(_i)
                    u = pynvml.nvmlDeviceGetUtilizationRates(h)
                    nvml_utils[_i] = {"gpu_pct": u.gpu, "mem_pct": u.memory}
            finally:
                # Pair every nvmlInit with a shutdown so repeated polling of
                # this endpoint does not accumulate NVML init references.
                pynvml.nvmlShutdown()
        except Exception:
            pass  # pynvml unavailable — utilization fields omitted

        gpus = []
        for i in range(torch.cuda.device_count()):
            free, total = torch.cuda.mem_get_info(i)
            entry: dict = {
                "index": i,
                "name": torch.cuda.get_device_name(i),
                "memory_total_mb": round(total / 1e6),
                "memory_used_mb": round((total - free) / 1e6),
                "memory_free_mb": round(free / 1e6),
            }
            # NOTE(review): assumes NVML and torch enumerate devices in the
            # same order — confirm when CUDA_VISIBLE_DEVICES is in use.
            if i in nvml_utils:
                entry["utilization_gpu_pct"] = nvml_utils[i]["gpu_pct"]
                entry["utilization_mem_pct"] = nvml_utils[i]["mem_pct"]
            gpus.append(entry)
        return {"available": True, "gpus": gpus}
    except Exception:
        return {"available": False, "gpus": []}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _sse(event: str, data: dict) -> str:
"""Format a Server-Sent Event."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"