| import logging |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
| from dotenv import load_dotenv |
| from fastapi import FastAPI |
| from src.models.hf_download import download_all |
| from src.pipelines.predict_pipeline import PredictPipeline |
| from src.pipelines.predict_all_pipeline import PredictAllPipeline |
| from src.pipelines.fasttext_pipeline import FastTextPipeline |
| from src.api.health import router as health_router |
| from src.api.predict import router as predict_router |
| from src.api.predict_all import router as predict_all_router |
|
|
|
|
# Load variables from a .env file into the process environment (python-dotenv).
load_dotenv()
log = logging.getLogger(name=__name__)

# Local directory searched for a model when download_all() does not supply one;
# the fastText model is also expected underneath it.
MODELS_ROOT = Path("models")
# ONNX model variants that lifespan() tries to load, in this order.
VALID_MODES = "marker", "qa_m", "qa_b"
|
|
|
|
def _resolve_model_dir(downloaded, mode):
    """Return the local directory for *mode* as a ``Path``.

    Uses the entry in *downloaded* when present, otherwise falls back to
    ``MODELS_ROOT / mode``. The downloaded value is wrapped in ``Path`` so that
    plain-string paths (which do not support ``/``) work as well.
    """
    local_dir = downloaded.get(mode)
    if local_dir is None:
        return MODELS_ROOT / mode
    return Path(local_dir)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every available model into ``app.state`` before serving requests.

    For each mode in ``VALID_MODES`` an ``model.onnx`` file is looked up in the
    directory reported by ``download_all()`` (falling back to
    ``MODELS_ROOT / mode``). Each model found is exposed as
    ``app.state.pipeline_<mode>`` and registered with a shared
    ``PredictAllPipeline``. An optional fastText model at
    ``MODELS_ROOT / "fasttext" / "model.bin"`` is handled the same way, exposed
    as ``app.state.pipeline_fasttext``. Missing models are logged and skipped
    instead of aborting startup. The combined pipeline is stored as
    ``app.state.all_models_pipeline``.

    Args:
        app: The FastAPI application whose ``state`` is populated.

    Yields:
        Control back to FastAPI once model loading has finished.
    """
    # Presumably maps mode name -> local directory; `or {}` guards against a
    # None return (TODO confirm download_all's contract).
    downloaded = download_all() or {}

    all_pipeline = PredictAllPipeline()

    for mode in VALID_MODES:
        onnx_path = _resolve_model_dir(downloaded, mode) / "model.onnx"
        if not onnx_path.exists():
            log.warning("%s not available — skipping", mode)
            continue

        log.info("Loading PredictPipeline (%s)", mode)
        # setattr goes through Starlette's State.__setattr__, so the pipeline is
        # stored the same way as pipeline_fasttext / all_models_pipeline below,
        # rather than being written directly into the instance __dict__.
        setattr(
            app.state,
            f"pipeline_{mode}",
            PredictPipeline(
                onnx_path=onnx_path,
                mode=mode,
                model_name=mode,
            ),
        )

        all_pipeline.add_model(mode, onnx_path, mode)

    # NOTE(review): unlike the ONNX modes, this path ignores `downloaded`;
    # confirm whether download_all() also fetches the fastText model.
    fasttext_path = MODELS_ROOT / "fasttext" / "model.bin"
    if fasttext_path.exists():
        log.info("Loading FastTextPipeline")
        app.state.pipeline_fasttext = FastTextPipeline(fasttext_path, "fasttext")
        all_pipeline.add_fasttext("fasttext", fasttext_path)
    else:
        log.warning("fastText model not available — skipping")

    app.state.all_models_pipeline = all_pipeline
    available = list(all_pipeline.models.keys()) + list(all_pipeline.fasttext_models.keys())
    log.info("Available models: %s", available)

    yield
|
|
|
|
# The ASGI application. `lifespan` loads the models at startup; root_path is
# "/api" (presumably the prefix added by a reverse proxy — TODO confirm).
app = FastAPI(
    lifespan=lifespan,
    root_path="/api",
    title="Entity Sentiment Classification API",
    description="Classify sentiment (positive, neutral, negative) for entities in text.",
    version="1.0.0",
)

# Register the routers in order: health, predict, predict_all.
for _router in (health_router, predict_router, predict_all_router):
    app.include_router(_router)
del _router
|
|