| import logging |
| from contextlib import asynccontextmanager |
|
|
| import torch |
| from dotenv import load_dotenv |
| from fastapi import FastAPI |
|
|
| from src.api import fix_newlines, fix_newlines_all_models, health |
| from src.models.export_and_download import download_all_models |
| from src.pipelines.all_models_pipeline import AllModelsPipeline |
| from src.pipelines.one_model_pipeline import OneModelPipeline |
| from src.pipelines.sat_loader import load_sat |
|
|
# Module-level logger used by the startup (lifespan) code in this file.
log = logging.getLogger(__name__)

# Populate os.environ from a local .env file, if present (python-dotenv).
# NOTE(review): this runs *after* the src.* imports above, so any of those
# modules that read environment variables at import time will not see values
# coming from .env — confirm none of them do.
load_dotenv()
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the inference pipelines on startup and release them on shutdown.

    Startup:
      1. Load the SAT-3L-SM model via ``load_sat`` on CUDA when available,
         otherwise on CPU.
      2. Fetch all models with ``download_all_models()``.
      3. Set ``app.state.one_model_pipeline`` to a ``OneModelPipeline`` built
         from the "bert" entry when its ``model.onnx`` exists, else ``None``.
      4. Set ``app.state.all_models_pipeline`` to an ``AllModelsPipeline``
         containing every downloaded model whose directory has a
         ``model.onnx`` (others are skipped with a warning).

    Shutdown:
      Clears both ``app.state`` pipeline attributes, drops the local model
      references and, on CUDA, asks torch to return cached GPU memory.

    Args:
        app: The FastAPI application whose ``state`` receives the pipelines.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info("Loading SAT-3L-SM on %s ...", device)
    sat = load_sat(device=device)

    # NOTE(review): presumably a mapping of model name -> local directory
    # (pathlib.Path); the code below relies on `/` and `.exists()` on values.
    downloaded = download_all_models()

    # Single-model pipeline: only available when the "bert" export exists.
    bert_dir = downloaded.get("bert")
    if bert_dir and (bert_dir / "model.onnx").exists():
        log.info("Loading OneModelPipeline (bert)")
        app.state.one_model_pipeline = OneModelPipeline(
            onnx_path=bert_dir / "model.onnx",
            tokenizer_path=bert_dir,
            sat_model=sat,
            model_name="bert",
        )
    else:
        log.warning("bert not available — OneModelPipeline disabled")
        app.state.one_model_pipeline = None

    # Multi-model pipeline: register every model that has an ONNX export.
    all_pipeline = AllModelsPipeline(sat_model=sat)
    for name, local_dir in downloaded.items():
        onnx_path = local_dir / "model.onnx"
        if not onnx_path.exists():
            log.warning("Skipping %s: model.onnx not found in %s", name, local_dir)
            continue
        log.info("Adding %s to AllModelsPipeline", name)
        all_pipeline.add_model(name, onnx_path, local_dir)
    app.state.all_models_pipeline = all_pipeline

    try:
        yield
    finally:
        # Drop every reference this function holds so the pipelines (and the
        # SAT model passed into them) become collectable on shutdown.
        app.state.one_model_pipeline = None
        app.state.all_models_pipeline = None
        del sat, all_pipeline
        if device == "cuda":
            # Only returns cached blocks that are no longer referenced.
            torch.cuda.empty_cache()
|
|
|
|
# Application instance. `lifespan` loads the models before requests are
# served; root_path="/api" declares the prefix the app is served under
# (e.g. behind a reverse proxy) — confirm against the deployment config.
app = FastAPI(
    title="Newline Fixer API",
    description="ML service for fixing newline placement in English text",
    root_path="/api",
    lifespan=lifespan,
)

# Mount each endpoint module's router, preserving registration order.
for _endpoint_module in (health, fix_newlines, fix_newlines_all_models):
    app.include_router(_endpoint_module.router)
del _endpoint_module
|
|