# app.py from __future__ import annotations import warnings from typing import List, Literal, Optional, Tuple from config import MODEL_PATH, REAL_LABEL, API_KEY import joblib from fastapi import FastAPI, Header, HTTPException from helper import _combine from schemas import PredictOut, PredictBatchIn, PredictIn, PredictBatchOut # Suppress sklearn version warnings warnings.filterwarnings("ignore", category=UserWarning, module="sklearn") warnings.filterwarnings("ignore", message=".*InconsistentVersionWarning.*") # ========================= # Load calibrated model # (Pipeline: TF-IDF -> CalibratedClassifierCV(LinearSVC)) # ========================= # Additional specific suppression for sklearn version warnings try: from sklearn.exceptions import InconsistentVersionWarning warnings.filterwarnings("ignore", category=InconsistentVersionWarning) except ImportError: # Fallback for older sklearn versions pass # Guard against double loading if 'PIPE' not in globals(): try: print("Loading model from:", MODEL_PATH) with warnings.catch_warnings(): warnings.simplefilter("ignore") PIPE = joblib.load(MODEL_PATH) print("Model loaded successfully") except Exception as e: print(f"Error loading model: {e}") raise # Lấy thứ tự class từ estimator cuối để map xác suất cho chắc try: classes = list(PIPE.named_steps["clf"].classes_) except Exception: classes = list(getattr(PIPE, "classes_", [0, 1])) # fallback print(f"Model classes: {classes}") IDX_REAL = classes.index(REAL_LABEL) IDX_FAKE = classes.index(0) print(f"Real index: {IDX_REAL}, Fake index: {IDX_FAKE}") else: print("Model already loaded, skipping reload...") # ========================= # Core inference # ========================= def infer_one(inp: PredictIn) -> PredictOut: text_all = inp.text_all.strip().lower() if inp.text_all else _combine(inp.title, inp.text) # Suppress warnings during prediction with warnings.catch_warnings(): warnings.simplefilter("ignore") probs = PIPE.predict_proba([text_all])[0] prob_real = float(probs[IDX_REAL]) prob_fake = float(probs[IDX_FAKE]) label = "real" if prob_real >= 0.5 else "fake" return PredictOut( label=label, prob_real=prob_real, prob_fake=prob_fake, ) def infer_batch(items: List[PredictIn]) -> List[PredictOut]: return [infer_one(x) for x in items] # ========================= # FastAPI endpoints # ========================= app = FastAPI( title="SVM Fake/Real News Classifier", description="API for classifying news as real or fake using SVM with TF-IDF features", version="1.0.0" ) @app.get("/") def root(): return { "message": "SVM Fake/Real News Classifier API", "endpoints": { "predict": "/predict", "predict_batch": "/predict_batch", "health": "/health" }, "model_info": { "classes": ["fake", "real"], "model_path": MODEL_PATH, "calibrated": True } } @app.get("/health") def health_check(): return {"status": "healthy", "model_loaded": 'PIPE' in globals()} @app.post("/predict", response_model=PredictOut) def predict(payload: PredictIn, x_api_key: str = Header(default="")): if x_api_key != API_KEY: raise HTTPException(status_code=401, detail="Unauthorized") return infer_one(payload) @app.post("/predict_batch", response_model=PredictBatchOut) def predict_batch(payload: PredictBatchIn, x_api_key: str = Header(default="")): if x_api_key != API_KEY: raise HTTPException(status_code=401, detail="Unauthorized") return PredictBatchOut(results=infer_batch(payload.items)) if __name__ == "__main__": import uvicorn print("===== Application Ready =====") print("FastAPI app initialized successfully") print("API endpoints available at /predict and /predict_batch") print("API documentation at /docs") print("================================") uvicorn.run(app, host="0.0.0.0", port=6778)