"""FastAPI entry point: loads sentiment models at startup and wires routers.

Models are looked up first in the directories reported by ``download_all()``
and, failing that, under ``models/<mode>``.  A model whose ``model.onnx`` (or
``model.bin`` for fastText) is missing is skipped with a warning rather than
aborting startup.
"""

import logging
from contextlib import asynccontextmanager
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI

from src.api.health import router as health_router
from src.api.predict import router as predict_router
from src.api.predict_all import router as predict_all_router
from src.models.hf_download import download_all
from src.pipelines.fasttext_pipeline import FastTextPipeline
from src.pipelines.predict_all_pipeline import PredictAllPipeline
from src.pipelines.predict_pipeline import PredictPipeline

load_dotenv()

log = logging.getLogger(__name__)

# Fallback location for model artifacts when download_all() has no entry.
MODELS_ROOT = Path("models")
# Supported transformer model modes; each maps to one PredictPipeline.
VALID_MODES = ("marker", "qa_m", "qa_b")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every available model into ``app.state`` before serving requests.

    For each mode in ``VALID_MODES`` a :class:`PredictPipeline` is stored as
    ``app.state.pipeline_<mode>`` and registered with the aggregate
    :class:`PredictAllPipeline` (stored as ``app.state.all_models_pipeline``).
    A fastText model is loaded from ``models/fasttext/model.bin`` if present.
    Missing models are logged and skipped.
    """
    downloaded = download_all()
    all_pipeline = PredictAllPipeline()

    for mode in VALID_MODES:
        # Prefer the downloaded location; fall back to the local models dir.
        local_dir = downloaded.get(mode)
        if local_dir is None:
            local_dir = MODELS_ROOT / mode
        onnx_path = local_dir / "model.onnx"
        if not onnx_path.exists():
            log.warning("%s not available — skipping", mode)
            continue
        log.info("Loading PredictPipeline (%s)", mode)
        # setattr is the supported way to add dynamic attributes to app.state
        # (writing into __dict__ bypasses Starlette's State handling).
        setattr(
            app.state,
            f"pipeline_{mode}",
            PredictPipeline(onnx_path=onnx_path, mode=mode, model_name=mode),
        )
        all_pipeline.add_model(mode, onnx_path, mode)

    fasttext_path = MODELS_ROOT / "fasttext" / "model.bin"
    if fasttext_path.exists():
        log.info("Loading FastTextPipeline")
        app.state.pipeline_fasttext = FastTextPipeline(fasttext_path, "fasttext")
        all_pipeline.add_fasttext("fasttext", fasttext_path)
    else:
        log.warning("fastText model not available — skipping")

    app.state.all_models_pipeline = all_pipeline
    available = list(all_pipeline.models.keys()) + list(all_pipeline.fasttext_models.keys())
    log.info("Available models: %s", available)

    yield


app = FastAPI(
    title="Entity Sentiment Classification API",
    description="Classify sentiment (positive, neutral, negative) for entities in text.",
    version="1.0.0",
    lifespan=lifespan,
    root_path="/api",
)

app.include_router(health_router)
app.include_router(predict_router)
app.include_router(predict_all_router)