sv-task / app.py
lamossta's picture
logger and main classes for fe and be
61af0ed
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from dotenv import load_dotenv
from fastapi import FastAPI
from src.models.hf_download import download_all
from src.pipelines.predict_pipeline import PredictPipeline
from src.pipelines.predict_all_pipeline import PredictAllPipeline
from src.pipelines.fasttext_pipeline import FastTextPipeline
from src.api.health import router as health_router
from src.api.predict import router as predict_router
from src.api.predict_all import router as predict_all_router
# Pull settings from a .env file (if one is found) into the process environment.
# NOTE(review): this runs after the src.* imports above, so any of those modules
# that read the environment at import time would not see .env values — confirm
# none of them do.
load_dotenv()
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Relative directory holding local model files: one sub-directory per mode,
# plus "fasttext". Being relative, it resolves against the working directory.
MODELS_ROOT = Path("models")
# Modes the startup hook tries to load. Each value is used as the lookup key in
# the download_all() result, as the fallback sub-directory name under
# MODELS_ROOT, and as the suffix of the "pipeline_<mode>" state attribute.
VALID_MODES = ("marker", "qa_m", "qa_b")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every available model at startup and publish it on ``app.state``.

    For each mode in ``VALID_MODES`` the directory reported by
    ``download_all()`` is used, falling back to ``MODELS_ROOT / <mode>`` when
    the download result has no entry for that mode. A mode whose
    ``model.onnx`` file does not exist is skipped with a warning instead of
    aborting startup. The fastText model is looked up only at
    ``MODELS_ROOT / "fasttext" / "model.bin"`` and is likewise optional.

    State attributes written:
        pipeline_<mode>: a ``PredictPipeline`` for each mode that was loaded.
        pipeline_fasttext: a ``FastTextPipeline``, only if its model exists.
        all_models_pipeline: the ``PredictAllPipeline`` that received every
            loaded model (always set, even if nothing was loaded).

    Args:
        app: The FastAPI application whose ``state`` receives the pipelines.

    Yields:
        None. No teardown is performed after the yield.
    """
    # NOTE(review): download_all() is called synchronously, so the event loop
    # is blocked until it returns — presumably acceptable during startup.
    downloaded = download_all()
    all_pipeline = PredictAllPipeline()

    for mode in VALID_MODES:
        local_dir = downloaded.get(mode)
        if local_dir is None:
            # Nothing reported by the download step: fall back to the local tree.
            local_dir = MODELS_ROOT / mode

        onnx_path = local_dir / "model.onnx"
        if not onnx_path.exists():
            log.warning("%s not available — skipping", mode)
            continue

        log.info("Loading PredictPipeline (%s)", mode)
        # Use setattr rather than poking app.state.__dict__: it goes through
        # the State object's own attribute handling, so these pipelines are
        # stored the same way as pipeline_fasttext and all_models_pipeline.
        setattr(
            app.state,
            f"pipeline_{mode}",
            PredictPipeline(
                onnx_path=onnx_path,
                mode=mode,
                model_name=mode,
            ),
        )
        all_pipeline.add_model(mode, onnx_path, mode)

    fasttext_path = MODELS_ROOT / "fasttext" / "model.bin"
    if fasttext_path.exists():
        log.info("Loading FastTextPipeline")
        app.state.pipeline_fasttext = FastTextPipeline(fasttext_path, "fasttext")
        all_pipeline.add_fasttext("fasttext", fasttext_path)
    else:
        log.warning("fastText model not available — skipping")

    app.state.all_models_pipeline = all_pipeline

    available = [*all_pipeline.models.keys(), *all_pipeline.fasttext_models.keys()]
    log.info("Available models: %s", available)

    yield
# The ASGI application object. Models are loaded by the `lifespan` hook at
# startup. root_path is "/api" — presumably the service is deployed behind a
# proxy that exposes it under that prefix; confirm against the deployment.
app = FastAPI(
    lifespan=lifespan,
    root_path="/api",
    version="1.0.0",
    title="Entity Sentiment Classification API",
    description="Classify sentiment (positive, neutral, negative) for entities in text.",
)

# Register the routers in the same order as before: health, predict, predict_all.
for _router in (health_router, predict_router, predict_all_router):
    app.include_router(_router)