# bc-test / main.py
# Last commit by lamossta: "updated readme and added report" (8159f14)
import logging
from contextlib import asynccontextmanager
import torch
from dotenv import load_dotenv
from fastapi import FastAPI
from src.api import fix_newlines, fix_newlines_all_models, health
from src.models.export_and_download import download_all_models
from src.pipelines.all_models_pipeline import AllModelsPipeline
from src.pipelines.one_model_pipeline import OneModelPipeline
from src.pipelines.sat_loader import load_sat
# Module-level logger. No handler or level is configured in this file.
log = logging.getLogger(name=__name__)

# Populate os.environ from a local .env file.
# NOTE(review): this runs after the src.* imports above; if any of those
# modules read environment variables at import time, they will not see
# values coming from .env — confirm this ordering is intended.
load_dotenv()
def _onnx_file(local_dir):
    """Return the expected ONNX export path inside a downloaded model directory."""
    return local_dir / "model.onnx"


def _build_one_model_pipeline(downloaded, sat):
    """Build the single-model pipeline backed by the "bert" download.

    Args:
        downloaded: Mapping of model name -> local directory, as returned by
            ``download_all_models()``. Values are presumably ``pathlib.Path``
            objects (they are joined with ``/``) — TODO confirm.
        sat: The already-loaded SAT model, shared with the other pipeline.

    Returns:
        A ``OneModelPipeline``, or ``None`` when there is no "bert" entry or
        its directory has no ``model.onnx``.
    """
    bert_dir = downloaded.get("bert")
    if not bert_dir or not _onnx_file(bert_dir).exists():
        log.warning("bert not available — OneModelPipeline disabled")
        return None

    log.info("Loading OneModelPipeline (bert)")
    return OneModelPipeline(
        onnx_path=_onnx_file(bert_dir),
        tokenizer_path=bert_dir,
        sat_model=sat,
        model_name="bert",
    )


def _build_all_models_pipeline(downloaded, sat):
    """Build the multi-model pipeline from every download that has a model.onnx.

    Directories without ``model.onnx`` are skipped with a warning. A warning is
    also logged when no model could be added, because the pipeline would then
    be registered empty.

    Args:
        downloaded: Mapping of model name -> local directory (see
            ``_build_one_model_pipeline``).
        sat: The already-loaded SAT model, shared with the other pipeline.

    Returns:
        An ``AllModelsPipeline`` (possibly containing no models).
    """
    pipeline = AllModelsPipeline(sat_model=sat)
    added = 0
    for name, local_dir in downloaded.items():
        onnx_path = _onnx_file(local_dir)
        if not onnx_path.exists():
            log.warning("Skipping %s: model.onnx not found in %s", name, local_dir)
            continue
        log.info("Adding %s to AllModelsPipeline", name)
        pipeline.add_model(name, onnx_path, local_dir)
        added += 1

    if added == 0:
        log.warning("No model.onnx found in any download — AllModelsPipeline is empty")
    return pipeline


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models at startup and expose the pipelines on ``app.state``.

    On startup this loads the SAT model, calls ``download_all_models()``, and
    stores:

    * ``app.state.one_model_pipeline``: a ``OneModelPipeline`` for "bert", or
      ``None`` when bert is unavailable;
    * ``app.state.all_models_pipeline``: an ``AllModelsPipeline`` containing
      every downloaded model that ships a ``model.onnx``.

    On shutdown both attributes are reset to ``None``, so the application no
    longer holds references to the pipelines or the SAT model they share.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info("Loading SAT-3L-SM on %s ...", device)
    sat = load_sat(device=device)

    # NOTE(review): load_sat() and download_all_models() are synchronous calls
    # made on the event loop; acceptable during startup, before requests are served.
    downloaded = download_all_models()

    app.state.one_model_pipeline = _build_one_model_pipeline(downloaded, sat)
    app.state.all_models_pipeline = _build_all_models_pipeline(downloaded, sat)

    try:
        yield
    finally:
        # Drop the app's references so the models can be garbage-collected.
        app.state.one_model_pipeline = None
        app.state.all_models_pipeline = None
# The ASGI application object. root_path="/api" declares the path prefix the
# app is served under; startup/shutdown model handling lives in `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Newline Fixer API",
    description="ML service for fixing newline placement in English text",
    root_path="/api",
)

# Attach each module's `router`, in this exact order.
for _router_module in (health, fix_newlines, fix_newlines_all_models):
    app.include_router(_router_module.router)
del _router_module