|
|
|
|
|
from __future__ import annotations

import secrets
import warnings
from typing import List, Literal, Optional, Tuple

import joblib
from fastapi import FastAPI, Header, HTTPException

from config import MODEL_PATH, REAL_LABEL, API_KEY
from helper import _combine
from schemas import PredictOut, PredictBatchIn, PredictIn, PredictBatchOut
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn") |
|
|
warnings.filterwarnings("ignore", message=".*InconsistentVersionWarning.*") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Suppress sklearn's version-mismatch warning by class when available; older
# sklearn releases do not define it, in which case the message-based filter
# registered above is the only (and sufficient) safeguard.
try:
    from sklearn.exceptions import InconsistentVersionWarning
except ImportError:
    # Older sklearn: nothing to register by category.
    pass
else:
    warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
|
|
|
|
|
|
|
|
# Idempotent model bootstrap: the guard prevents re-reading the model file
# when the module is re-executed (e.g. dev-server reloads).
if 'PIPE' not in globals():
    try:
        print("Loading model from:", MODEL_PATH)
        with warnings.catch_warnings():
            # joblib.load can emit sklearn version-mismatch warnings.
            warnings.simplefilter("ignore")
            PIPE = joblib.load(MODEL_PATH)
        print("Model loaded successfully")
    except Exception as e:
        # The service is useless without a model: report and re-raise so the
        # process fails fast instead of serving errors per-request.
        print(f"Error loading model: {e}")
        raise

    # Recover the class order predict_proba uses. Prefer the pipeline's
    # final "clf" step; fall back to a top-level classes_ attribute, and
    # finally to a conventional [0, 1] binary layout.
    try:
        classes = list(PIPE.named_steps["clf"].classes_)
    except Exception:
        classes = list(getattr(PIPE, "classes_", [0, 1]))

    print(f"Model classes: {classes}")
    IDX_REAL = classes.index(REAL_LABEL)
    # BUGFIX: previously IDX_FAKE = classes.index(0), which hard-coded the
    # fake label as 0 while the real label is config-driven. That raises
    # ValueError for non-numeric labels (e.g. "fake"/"real") and could even
    # equal IDX_REAL when REAL_LABEL == 0. In this binary setup the fake
    # class is simply the other index.
    IDX_FAKE = next(i for i in range(len(classes)) if i != IDX_REAL)
    print(f"Real index: {IDX_REAL}, Fake index: {IDX_FAKE}")
else:
    print("Model already loaded, skipping reload...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_one(inp: PredictIn) -> PredictOut:
    """Classify a single news item.

    Uses ``inp.text_all`` (normalized to stripped lowercase) when provided,
    otherwise combines ``inp.title`` and ``inp.text`` via ``_combine``.
    Returns a PredictOut with the winning label and both class probabilities.
    """
    if inp.text_all:
        document = inp.text_all.strip().lower()
    else:
        document = _combine(inp.title, inp.text)

    # Suppress sklearn runtime warnings during scoring.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        probabilities = PIPE.predict_proba([document])[0]

    p_real = float(probabilities[IDX_REAL])
    p_fake = float(probabilities[IDX_FAKE])

    # Ties (exactly 0.5) resolve to "real", matching the >= threshold.
    verdict = "fake" if p_real < 0.5 else "real"

    return PredictOut(label=verdict, prob_real=p_real, prob_fake=p_fake)
|
|
|
|
|
|
|
|
def infer_batch(items: List[PredictIn]) -> List[PredictOut]:
    """Run infer_one over each item, preserving input order."""
    return list(map(infer_one, items))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Application metadata kept in one place so the OpenAPI docs stay tidy.
_APP_META = dict(
    title="SVM Fake/Real News Classifier",
    description="API for classifying news as real or fake using SVM with TF-IDF features",
    version="1.0.0",
)

app = FastAPI(**_APP_META)
|
|
|
|
|
@app.get("/")
def root():
    """Landing endpoint: advertises available routes and static model info."""
    endpoints = {
        "predict": "/predict",
        "predict_batch": "/predict_batch",
        "health": "/health",
    }
    # NOTE(review): "classes" and "calibrated" are hard-coded here, not read
    # from the loaded pipeline — confirm they match the deployed model.
    model_info = {
        "classes": ["fake", "real"],
        "model_path": MODEL_PATH,
        "calibrated": True,
    }
    return {
        "message": "SVM Fake/Real News Classifier API",
        "endpoints": endpoints,
        "model_info": model_info,
    }
|
|
|
|
|
@app.get("/health")
def health_check():
    """Liveness probe; also reports whether the model global was loaded."""
    model_loaded = "PIPE" in globals()
    return {"status": "healthy", "model_loaded": model_loaded}
|
|
|
|
|
@app.post("/predict", response_model=PredictOut)
def predict(payload: PredictIn, x_api_key: str = Header(default="")):
    """Classify one news item.

    Requires the ``X-API-Key`` header to match the configured API_KEY;
    returns 401 otherwise.
    """
    # SECURITY: compare_digest is constant-time, so response latency does not
    # leak how much of the key prefix matched (unlike plain !=).
    if not secrets.compare_digest(x_api_key, API_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized")
    return infer_one(payload)
|
|
|
|
|
@app.post("/predict_batch", response_model=PredictBatchOut)
def predict_batch(payload: PredictBatchIn, x_api_key: str = Header(default="")):
    """Classify a batch of news items in one request.

    Requires the ``X-API-Key`` header to match the configured API_KEY;
    returns 401 otherwise.
    """
    # SECURITY: constant-time key comparison, same rationale as /predict.
    if not secrets.compare_digest(x_api_key, API_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized")
    return PredictBatchOut(results=infer_batch(payload.items))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Startup banner, printed one line at a time like the original.
    banner = (
        "===== Application Ready =====",
        "FastAPI app initialized successfully",
        "API endpoints available at /predict and /predict_batch",
        "API documentation at /docs",
        "================================",
    )
    for line in banner:
        print(line)

    uvicorn.run(app, host="0.0.0.0", port=6778)
|
|
|