| import collections
|
| import logging
|
| import os
|
| import sys
|
| import threading
|
| import time
|
| from contextlib import asynccontextmanager
|
| from pathlib import Path
|
|
|
| from dotenv import load_dotenv
|
| from fastapi import Depends, FastAPI, Header, HTTPException
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from fastapi.responses import FileResponse, JSONResponse
|
|
|
# Load variables from a local .env file (if any) into the process environment
# before the configuration values further down are read from os.environ.
load_dotenv()

# Root logging setup: INFO and above, written to stdout with a short
# time-of-day stamp and a left-aligned, 8-character level column.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
    datefmt="%H:%M:%S",
    format="%(asctime)s | %(levelname)-8s | %(message)s",
)

# Module-level logger used by everything in this file.
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# --- Service configuration (all overridable through environment variables) ---
PORT = int(os.getenv("PORT", "8000"))
API_KEY = os.getenv("API_KEY", "")
MODEL_VERSION = os.getenv("MODEL_VERSION", "v1.0.0")
HF_TOKEN = os.getenv("HF_TOKEN", "")
DEVICE = os.getenv("DEVICE", "cpu")

# --- Rate limiting: at most RATE_LIMIT_REQUESTS per RATE_LIMIT_WINDOW_S seconds ---
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60"))
RATE_LIMIT_WINDOW_S = int(os.getenv("RATE_LIMIT_WINDOW_S", "60"))
# Per-key request timestamps; every access is guarded by _rate_lock.
_rate_buckets: dict = {}
_rate_lock: threading.Lock = threading.Lock()

# --- Model artefact locations on disk and the Hub repository they come from ---
ONNX_PATH = Path("models/onnx/immuniwatch_classifier.onnx")
THRESHOLDS_PATH = Path("models/onnx/thresholds.json")
CONFIG_PATH = Path("models/onnx/model_config.json")
LORA_REPO = "AHFIDAILabs/immuniwatch-lora-classifier"

# Wall-clock time at import, used to report uptime from /health.
_start_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _download_model_files() -> None:
    """Fetch model artefacts that are not yet on disk from ``LORA_REPO``.

    Four files are checked: the ONNX graph, its ``.onnx.data`` companion,
    the thresholds JSON and the model config JSON. Files that already exist
    are skipped; if nothing is missing the function returns immediately.

    Any failure (including ``huggingface_hub`` not being importable) is
    logged as an error and swallowed, so this function never raises.
    """
    targets = {
        "immuniwatch_classifier.onnx": ONNX_PATH,
        "immuniwatch_classifier.onnx.data": ONNX_PATH.parent / "immuniwatch_classifier.onnx.data",
        "thresholds.json": THRESHOLDS_PATH,
        "model_config.json": CONFIG_PATH,
    }

    # Collect the (remote name, local destination) pairs still to download.
    pending = []
    for remote_name, destination in targets.items():
        if destination.exists():
            continue
        pending.append((remote_name, destination))

    if len(pending) == 0:
        return

    log.info("Downloading %d model file(s) from %s ...", len(pending), LORA_REPO)
    try:
        # Imported lazily so the dependency is only needed when a download
        # is actually required.
        from huggingface_hub import hf_hub_download

        for remote_name, destination in pending:
            log.info(" -> %s", remote_name)
            target_dir = destination.parent
            target_dir.mkdir(parents=True, exist_ok=True)
            hf_hub_download(
                repo_id=LORA_REPO,
                filename=remote_name,
                local_dir=str(target_dir),
                token=HF_TOKEN if HF_TOKEN else None,
            )
        log.info("Model files ready.")
    except Exception as exc:
        # Best effort: report the problem and let the caller carry on.
        log.error("Model download failed: %s", exc)
        log.error("Upload ONNX files to %s on HuggingFace Hub first.", LORA_REPO)
|
|
|
|
|
|
|
|
|
|
|
def _start_ingestion_worker() -> None:
    """Start the ingestion runner in a background daemon thread, if configured.

    Reads ``BLUESKY_HANDLE`` and ``YOUTUBE_API_KEY`` from the environment.
    When both are empty, the worker is not started and an info message is
    logged. Otherwise ``src.ingestion.direct_runner.run`` is launched in a
    daemon thread named ``ingestion-worker``.

    Import or thread-start failures are logged as warnings and never raised.
    """
    bluesky_handle = os.environ.get("BLUESKY_HANDLE", "")
    youtube_key = os.environ.get("YOUTUBE_API_KEY", "")
    if not bluesky_handle and not youtube_key:
        # The separator in this message was a mis-encoded dash character;
        # a plain ASCII hyphen keeps the log readable on any console.
        log.info("No connector credentials - ingestion worker not started.")
        return
    try:
        # Imported lazily so a missing or broken ingestion package is caught
        # below instead of failing at module import.
        from src.ingestion.direct_runner import run as run_ingestion

        # daemon=True: the thread does not keep the process alive on shutdown.
        t = threading.Thread(target=run_ingestion, daemon=True, name="ingestion-worker")
        t.start()
        log.info("Ingestion worker started in background thread.")
    except Exception as exc:
        log.warning("Ingestion worker failed to start: %s", exc)
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: prepare the service, yield, log shutdown.

    Before the ``yield``, in this order: download any missing model files,
    load the classifier, preload the embedder, start the ingestion worker.
    After the ``yield`` only a shutdown message is logged.
    """
    # Project imports are resolved here, when the lifespan starts, rather
    # than at module import.
    from src.models.classifier import load as load_classifier
    from src.intelligence.rag import preload_embedder

    log.info("Starting ImmuniWatch ML Service v%s", MODEL_VERSION)

    # Model artefacts must be on disk before the classifier is loaded.
    _download_model_files()

    classifier_options = {
        "onnx_path": str(ONNX_PATH),
        "thresholds_path": str(THRESHOLDS_PATH),
        "config_path": str(CONFIG_PATH),
        "tokenizer_repo": LORA_REPO,
        # An empty token is passed on as None.
        "hf_token": HF_TOKEN if HF_TOKEN else None,
    }
    load_classifier(**classifier_options)

    preload_embedder()
    _start_ingestion_worker()

    log.info("Service ready on port %d", PORT)
    yield
    log.info("Service shutting down.")
|
|
|
|
|
|
|
|
|
|
|
# ASGI application. Swagger UI is served at /docs, ReDoc is disabled, and
# startup/shutdown work is delegated to `lifespan`.
# The title separator was a mis-encoded dash character; a plain ASCII hyphen
# is used so the OpenAPI title renders correctly.
app = FastAPI(
    title="ImmuniWatch Nigeria - ML Service",
    version=MODEL_VERSION,
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url=None,
)

# CORS: any origin, method and header is accepted, with credentials
# explicitly disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_rate_limit(key: str) -> tuple:
    """Sliding-window rate limiter keyed by ``key``.

    Each accepted request appends a timestamp to the key's bucket. Timestamps
    older than ``RATE_LIMIT_WINDOW_S`` seconds are discarded before the
    bucket size is compared with ``RATE_LIMIT_REQUESTS``.

    Args:
        key: Identifier whose request history is tracked.

    Returns:
        ``(allowed, retry_after)``: ``(True, 0)`` when the request is
        accepted and recorded, otherwise ``(False, seconds)`` where
        ``seconds`` is how long to wait before retrying.
    """
    # Monotonic clock: interval arithmetic must not be disturbed by
    # wall-clock adjustments (NTP steps, manual changes).
    now = time.monotonic()
    window_start = now - RATE_LIMIT_WINDOW_S

    with _rate_lock:
        bucket = _rate_buckets.get(key)
        if bucket is None:
            bucket = collections.deque()
            _rate_buckets[key] = bucket

        # Drop timestamps that have fallen out of the window.
        while bucket and bucket[0] < window_start:
            bucket.popleft()

        if len(bucket) >= RATE_LIMIT_REQUESTS:
            if not bucket:
                # Only reachable when the configured limit is <= 0: there is
                # no oldest entry to wait for, so report one full window
                # instead of raising IndexError on bucket[0].
                return False, max(RATE_LIMIT_WINDOW_S, 1)
            # Time until the oldest entry expires, rounded up.
            retry_after = int(RATE_LIMIT_WINDOW_S - (now - bucket[0])) + 1
            return False, retry_after

        bucket.append(now)
        return True, 0
|
|
|
|
|
async def require_api_key(x_ml_api_key: str = Header(default=None)):
    """FastAPI dependency: authenticate via ``X-ML-API-Key`` and rate-limit.

    Args:
        x_ml_api_key: Value of the ``X-ML-API-Key`` request header, or
            ``None`` when the header is absent.

    Raises:
        HTTPException: 401 when the header is missing, when the server has no
            ``API_KEY`` configured, or when the supplied key does not match;
            429 (with a ``Retry-After`` header) when the key has exceeded its
            rate limit.
    """
    # Local import keeps this block self-contained.
    import hmac

    if x_ml_api_key is None:
        log.warning("Rejected - X-ML-API-Key header absent")
        raise HTTPException(status_code=401,
                            detail="X-ML-API-Key header is required")

    if not API_KEY:
        # Fail closed: with no key configured, an empty header value would
        # otherwise compare equal to the empty default and be accepted.
        log.error("Rejected - API_KEY is not configured on the server")
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Constant-time comparison on bytes: avoids leaking key contents through
    # timing, and compare_digest rejects non-ASCII str arguments.
    supplied = x_ml_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        log.warning("Rejected - X-ML-API-Key invalid")
        raise HTTPException(status_code=401, detail="Invalid API key")

    allowed, retry_after = _check_rate_limit(x_ml_api_key)
    if not allowed:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={"Retry-After": str(retry_after)},
        )
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return a small index document describing the service and its endpoints."""
    endpoint_map = {
        "docs": "/docs",
        "health": "/health",
        "dashboard": "/dashboard",
        "classify": "POST /classify",
        "batch": "POST /classify/batch",
    }
    # Identity fields first, then the endpoint listing, in a fixed key order.
    return {
        "service": "ImmuniWatch Nigeria ML Service",
        "version": MODEL_VERSION,
        "status": "running",
        **endpoint_map,
    }
|
|
|
|
|
|
|
|
|
|
|
@app.get("/dashboard", include_in_schema=False)
async def dashboard():
    """Serve ``dashboard.html`` (relative to the working directory) as HTML.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    html_file = Path("dashboard.html")
    if html_file.exists():
        return FileResponse(html_file, media_type="text/html")
    raise HTTPException(status_code=404, detail="Dashboard not found")
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Report service health.

    Returns a 503 JSON response while the classifier reports it is not
    loaded; otherwise returns a status document with the model version,
    device and uptime in whole seconds.
    """
    from src.models.classifier import is_loaded

    if is_loaded():
        uptime_seconds = int(time.time() - _start_time)
        return {
            "status": "ok",
            "model_loaded": True,
            "model_version": MODEL_VERSION,
            "device": DEVICE,
            "uptime_s": uptime_seconds,
        }

    # Classifier not available yet: signal "service unavailable".
    not_ready = {"status": "unavailable", "reason": "model loading"}
    return JSONResponse(status_code=503, content=not_ready)
|
|
|
|
|
|
|
|
|
|
|
# Imported here, after `app` and `require_api_key` are defined, rather than
# with the other imports at the top of the file.
# NOTE(review): presumably deliberate (e.g. to avoid a circular import with
# src.api.routes) - confirm before moving it.
from src.api.routes import router

# Mount the API routes with `require_api_key` attached as a dependency, so
# the key check and rate limit run for every route on this router.
app.include_router(router, dependencies=[Depends(require_api_key)])