import logging
from contextlib import asynccontextmanager
from pathlib import Path
from dotenv import load_dotenv
from fastapi import FastAPI
from src.models.hf_download import download_all
from src.pipelines.predict_pipeline import PredictPipeline
from src.pipelines.predict_all_pipeline import PredictAllPipeline
from src.pipelines.fasttext_pipeline import FastTextPipeline
from src.api.health import router as health_router
from src.api.predict import router as predict_router
from src.api.predict_all import router as predict_all_router


load_dotenv()
log = logging.getLogger(__name__)

MODELS_ROOT = Path("models")
VALID_MODES = ("marker", "qa_m", "qa_b")


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Fetch model artifacts; returns a mapping of mode -> local directory.
    downloaded = download_all()

    # Aggregate pipeline that can run every loaded model in one request.
    all_pipeline = PredictAllPipeline()

    for mode in VALID_MODES:
        # Fall back to the local models/ directory if the download was skipped.
        local_dir = downloaded.get(mode)
        if local_dir is None:
            local_dir = MODELS_ROOT / mode

        onnx_path = local_dir / "model.onnx"
        if not onnx_path.exists():
            log.warning(f"{mode} not available — skipping")
            continue

        log.info(f"Loading PredictPipeline ({mode})")
        setattr(
            app.state,
            f"pipeline_{mode}",
            PredictPipeline(
                onnx_path=onnx_path,
                mode=mode,
                model_name=mode,
            ),
        )

        all_pipeline.add_model(mode, onnx_path, mode)

    fasttext_path = MODELS_ROOT / "fasttext" / "model.bin"
    if fasttext_path.exists():
        log.info("Loading FastTextPipeline")
        app.state.pipeline_fasttext = FastTextPipeline(fasttext_path, "fasttext")
        all_pipeline.add_fasttext("fasttext", fasttext_path)
    else:
        log.warning("fastText model not available — skipping")

    app.state.all_models_pipeline = all_pipeline
    available = list(all_pipeline.models.keys()) + list(all_pipeline.fasttext_models.keys())
    log.info(f"Available models: {available}")

    yield


app = FastAPI(
    title="Entity Sentiment Classification API",
    description="Classify sentiment (positive, neutral, negative) for entities in text.",
    version="1.0.0",
    lifespan=lifespan,
    root_path="/api",
)

app.include_router(health_router)
app.include_router(predict_router)
app.include_router(predict_all_router)