"""FastAPI entry point for the Newline Fixer API.

On startup the ``lifespan`` handler loads the SAT-3L-SM model, downloads the
exported models, builds the single-model and all-models pipelines, and stores
them on ``app.state``. On shutdown it drops those references again.
"""

import logging
from contextlib import asynccontextmanager

import torch
from dotenv import load_dotenv
from fastapi import FastAPI

from src.api import fix_newlines, fix_newlines_all_models, health
from src.models.export_and_download import download_all_models
from src.pipelines.all_models_pipeline import AllModelsPipeline
from src.pipelines.one_model_pipeline import OneModelPipeline
from src.pipelines.sat_loader import load_sat

# NOTE(review): this runs after the ``src`` imports above; if any of those
# modules read environment variables at import time they will not see values
# from ``.env`` -- confirm, and move this call above those imports if so.
load_dotenv()

log = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup and release them on shutdown.

    Populates:
      * ``app.state.one_model_pipeline`` -- a ``OneModelPipeline`` over the
        ``bert`` model, or ``None`` when bert (or its ``model.onnx``) is missing.
      * ``app.state.all_models_pipeline`` -- an ``AllModelsPipeline`` holding
        every downloaded model whose directory contains ``model.onnx``.

    Both attributes are reset to ``None`` when the application shuts down.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info("Loading SAT-3L-SM on %s ...", device)
    # Blocking calls are tolerated here: this is application startup, before
    # the lifespan handler yields control to request handling.
    sat = load_sat(device=device)
    # Mapping of model name -> local model directory (path-like: it is joined
    # with ``/`` below).
    downloaded = download_all_models()

    # OneModelPipeline -- uses bert only.
    bert_dir = downloaded.get("bert")
    bert_onnx = bert_dir / "model.onnx" if bert_dir else None
    if bert_onnx is not None and bert_onnx.exists():
        log.info("Loading OneModelPipeline (bert)")
        app.state.one_model_pipeline = OneModelPipeline(
            onnx_path=bert_onnx,
            tokenizer_path=bert_dir,
            sat_model=sat,
            model_name="bert",
        )
    else:
        log.warning("bert not available — OneModelPipeline disabled")
        app.state.one_model_pipeline = None

    # AllModelsPipeline -- uses every downloaded model that has an ONNX export.
    all_pipeline = AllModelsPipeline(sat_model=sat)
    added = 0
    for name, local_dir in downloaded.items():
        onnx_path = local_dir / "model.onnx"
        if not onnx_path.exists():
            log.warning("Skipping %s: model.onnx not found in %s", name, local_dir)
            continue
        log.info("Adding %s to AllModelsPipeline", name)
        all_pipeline.add_model(name, onnx_path, local_dir)
        added += 1
    if not added:
        # Surface this explicitly; otherwise the only hint is per-model skips.
        log.warning("No models were added to AllModelsPipeline")
    app.state.all_models_pipeline = all_pipeline

    try:
        yield
    finally:
        # Drop every reference this handler created so the models can be
        # garbage-collected, then return cached GPU memory to the driver.
        app.state.one_model_pipeline = None
        app.state.all_models_pipeline = None
        del sat, all_pipeline
        if device == "cuda":
            torch.cuda.empty_cache()


app = FastAPI(
    title="Newline Fixer API",
    description="ML service for fixing newline placement in English text",
    lifespan=lifespan,
    root_path="/api",
)
app.include_router(health.router)
app.include_router(fix_newlines.router)
app.include_router(fix_newlines_all_models.router)