Spaces:
Configuration error
Configuration error
feat(backend): pull weights from HuggingFace Hub at startup when BACKEND_WEIGHTS_HUB_REPO is set
2461f82 | """FastAPI application entrypoint. | |
| Run locally with:: | |
| uvicorn --app-dir backend app.main:app --host 0.0.0.0 --port 8000 --reload | |
Startup order:
1. ``create_app`` loads the YAML ``AppConfig`` (research-side hyperparameters)
   and stashes it, together with the backend settings, on app state.
2. At lifespan start, resolve the weights + tokenizer locations via
   ``resolve_weights`` and load them into a ``CaptionPredictor`` singleton.
3. Optionally warm up the predictor so the first request doesn't pay TF's
   lazy build cost.
4. Wrap the predictor in a ``PredictorService`` and stash it on app state.
| The singleton lives on ``app.state.predictor_service``; routes pull it | |
| through a ``Depends`` so tests can override the dependency cleanly. | |
| """ | |
| from __future__ import annotations | |
| from collections.abc import AsyncIterator | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from app.api.routes import router | |
| from app.core.config import BackendSettings, get_backend_settings | |
| from app.core.logging import RequestContextMiddleware, configure_app_logging | |
| from app.services.predictor_service import PredictorService | |
| from app.services.weights_loader import resolve_weights | |
| from captioning.config import load_config | |
| from captioning.config.schema import AppConfig | |
| from captioning.inference import CaptionPredictor | |
| from captioning.utils import get_logger | |
# Module logger; ``lifespan`` emits key/value startup and shutdown events
# (``predictor_loading``, ``predictor_ready``, ``predictor_unloaded``) through it.
log = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Load the predictor at startup and release it at shutdown.

    Decorated with ``asynccontextmanager`` because FastAPI's ``lifespan``
    parameter expects an async context manager factory. Passing a bare
    async generator function only works through a deprecated fallback.

    Reads ``backend_settings`` and ``app_config`` from ``app.state`` (both
    seeded by ``create_app``). It then resolves the weights/tokenizer
    locations, builds a ``CaptionPredictor``, optionally warms it up, and
    publishes a ``PredictorService`` on ``app.state.predictor_service`` for
    the lifetime of the application. On shutdown the service reference is
    cleared, even if the application body raised.

    Args:
        app: The FastAPI application whose ``state`` carries settings/config.

    Yields:
        None, once the predictor service is ready to serve requests.
    """
    settings: BackendSettings = app.state.backend_settings
    config: AppConfig = app.state.app_config

    # NOTE(review): ``resolve_weights`` presumably fetches remote artifacts
    # (e.g. from the HuggingFace Hub) when configured; confirm in
    # app.services.weights_loader.
    weights_path, tokenizer_dir = resolve_weights(settings)
    log.info(
        "predictor_loading",
        weights=str(weights_path),
        tokenizer_dir=str(tokenizer_dir),
        model_version=settings.model_version,
    )
    predictor = CaptionPredictor.from_artifacts(
        weights_path=weights_path,
        tokenizer_dir=tokenizer_dir,
        config=config,
    )
    if settings.warmup:
        # Pay the lazy model-build cost now instead of on the first request.
        predictor.warmup()

    app.state.predictor_service = PredictorService(
        predictor=predictor,
        model_version=settings.model_version,
        max_upload_bytes=config.serve.max_upload_bytes,
    )
    log.info("predictor_ready", model_version=settings.model_version)
    try:
        yield
    finally:
        # Drop the reference so the predictor can be released on shutdown.
        app.state.predictor_service = None
        log.info("predictor_unloaded")
def create_app() -> FastAPI:
    """Construct and configure a fresh FastAPI application.

    Factory form so tests can build isolated apps. Configures logging, reads
    the backend settings and the YAML ``AppConfig``, and seeds ``app.state``
    with both. It also adds an empty ``predictor_service`` slot that the
    ``lifespan`` handler fills in. Finally it installs the CORS and
    request-context middleware and mounts the API router.

    Returns:
        The configured FastAPI application.
    """
    configure_app_logging()
    backend_settings = get_backend_settings()
    app_config = load_config(backend_settings.config_path)

    description = (
        "Production-grade inference service for the IEEE-published "
        "CNN+Transformer image captioning model."
    )
    application = FastAPI(
        title="Image Captioning API",
        version=backend_settings.api_version,
        description=description,
        lifespan=lifespan,
    )

    state = application.state
    state.backend_settings = backend_settings
    state.app_config = app_config
    # Placeholder until ``lifespan`` publishes the loaded predictor service.
    state.predictor_service = None

    application.add_middleware(
        CORSMiddleware,
        allow_origins=app_config.serve.cors_allowed_origins,
        allow_methods=["GET", "POST", "OPTIONS"],
        allow_headers=["*"],
        allow_credentials=False,
    )
    application.add_middleware(RequestContextMiddleware)
    application.include_router(router)
    return application
# Module-level instance served by ``uvicorn app.main:app``. Building it at
# import time runs ``create_app``: it configures logging, reads the backend
# settings and loads the YAML config. The model itself loads later, in ``lifespan``.
app = create_app()