| """
|
| app.py
|
| ------
|
| FastAPI service exposing the Reddit upvotes prediction pipeline.
|
|
|
| Endpoints
|
| ---------
|
| POST /predict_single — single prediction via hours_ago and post_utc
|
| POST /predict_batch — batch prediction via lists of hours_ago and post_utc
|
| GET /health — liveness check
|
| """
|
|
|
import logging
import os
import sys
from pathlib import Path
|
|
|
|
|
# Directory that contains this file; anchor for the local import paths below.
PROJECT_ROOT = Path(__file__).resolve().parent

# Prepend both the project root and its ``models`` directory to the import
# path so ``models.reddit_predict`` resolves regardless of the working
# directory (the bare ``models`` entry presumably lets modules inside it
# import each other by plain name — confirm). Resulting order at the front of
# sys.path: the ``models`` directory first, then the project root.
sys.path[0:0] = [str(PROJECT_ROOT / "models"), str(PROJECT_ROOT)]
|
|
|
| from fastapi import FastAPI, HTTPException
|
| from pydantic import BaseModel, Field, field_validator
|
| import uvicorn
|
| from typing import List, Dict, Any
|
|
|
| from models.reddit_predict import predict_single, predict_batch
|
|
|
# The ASGI application object. These metadata values are what FastAPI shows
# in the generated OpenAPI schema / interactive docs.
app = FastAPI(
    version="1.0.0",
    title="Reddit Upvotes Predictor",
    description="Predicts the number of upvotes a Reddit post will receive.",
)
|
|
|
|
|
# Absolute location of the model file passed to every prediction call,
# resolved relative to this source file rather than the working directory.
LOCAL_ARTIFACTS_MODEL_PATH = Path(__file__).resolve().parent.joinpath(
    "models", "artifacts", "model_reddit_final.pth"
)
|
|
|
|
|
|
|
|
|
|
|
class PostContent(BaseModel):
    """Shared post content fields, inherited by all request models.

    Both ``PredictSingleRequest`` and ``PredictBatchRequest`` subclass this,
    so these three required fields are part of every prediction request body.
    """

    # NOTE(review): unlike title/text, subreddit has no min_length, so an
    # empty string passes validation — confirm the pipeline tolerates that.
    subreddit: str = Field(..., description="Subreddit name (e.g. 'MachineLearning')")

    # Must be non-empty (min_length=1).
    title: str = Field(..., min_length=1, description="Post title")

    # Must be non-empty (min_length=1).
    text: str = Field(..., min_length=1, description="Post body text")
|
|
|
|
|
class PredictSingleRequest(PostContent):
    """Single prediction.

    Request body for ``POST /predict_single``: the ``PostContent`` fields
    plus one (hours_ago, post_utc) timing pair.
    """

    # Both timing fields are required (no defaults) and are forwarded
    # unchanged to predict_single. NOTE(review): units and reference point
    # (hours before what? Unix seconds for post_utc?) are not stated in this
    # file — confirm against models.reddit_predict.
    hours_ago: float

    post_utc: float
|
|
|
|
|
class PredictBatchRequest(PostContent):
    """Batch prediction.

    Request body for ``POST /predict_batch``: the shared post content plus
    two non-empty, parallel lists. The i-th element of each list together
    forms one timing pair, so both lists must have the same length.
    """

    hours_ago_list: List[float] = Field(..., min_length=1)

    post_utc_list: List[float] = Field(..., min_length=1)

    @field_validator("post_utc_list")
    @classmethod
    def validate_lengths(cls, v: List[float], info: Any) -> List[float]:
        """Reject requests whose two timing lists differ in length."""
        # Fields are validated in declaration order, so hours_ago_list is in
        # info.data here unless it failed its own validation; in that case its
        # error is already reported and the cross-check is skipped.
        hours = info.data.get("hours_ago_list")
        if hours is None or len(hours) == len(v):
            return v
        raise ValueError("hours_ago_list and post_utc_list must have the same length.")
|
|
|
|
|
class PredictSingleResponse(BaseModel):
    """Response body for ``POST /predict_single``."""

    # Predicted number of upvotes; declared as float, so any value the
    # endpoint passes in is validated/coerced to float by pydantic.
    prediction: float
|
|
|
|
|
class BatchPredictionItem(BaseModel):
    """One entry of a batch response: a timing pair and its prediction."""

    # Presumably the input timing pair this prediction belongs to, as
    # returned by predict_batch — verify against models.reddit_predict.
    hours_ago: float

    post_utc: float

    # Predicted number of upvotes for this timing pair.
    prediction: float
|
|
|
|
|
class PredictBatchResponse(BaseModel):
    """Response body for ``POST /predict_batch``."""

    # One item per result returned by predict_batch.
    predictions: List[BatchPredictionItem]

    # Number of entries in ``predictions``; the endpoint sets it to
    # len(results). No validator here enforces that relationship.
    count: int
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health", tags=["Meta"])
def health():
    """Liveness probe.

    Always answers with a constant payload; it does not touch the model or
    any other dependency, so it only shows that the process is serving HTTP.
    """
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
@app.post(
    "/predict_single",
    response_model=PredictSingleResponse,
    summary="Single prediction with mutual specific timestamp",
    tags=["Predict"],
)
def api_predict_single(request: PredictSingleRequest):
    """
    Predict upvotes for a single post context given both hours_ago and post_utc.

    All request fields are forwarded unchanged to
    ``models.reddit_predict.predict_single`` together with
    ``LOCAL_ARTIFACTS_MODEL_PATH``.

    Raises:
        HTTPException(422): predict_single raised ``ValueError``; the
            detail carries its message.
        HTTPException(500): predict_single raised ``FileNotFoundError``
            (e.g. the model artifact is missing). That is a server-side
            deployment fault, not a client error, so it is logged here and
            the server's filesystem path is not echoed to the caller.
    """
    try:
        result = predict_single(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago=request.hours_ago,
            post_utc=request.post_utc,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        # Input the pipeline rejected: report it as a client error.
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # Missing artifact is the server's problem; keep the path in the logs.
        logging.getLogger(__name__).exception("Model artifact not found")
        raise HTTPException(
            status_code=500,
            detail="Model artifact is not available on the server.",
        ) from e

    return PredictSingleResponse(prediction=result)
|
|
|
|
|
@app.post(
    "/predict_batch",
    response_model=PredictBatchResponse,
    summary="Batch prediction with matching length lists",
    tags=["Predict"],
)
def api_predict_batch(request: PredictBatchRequest):
    """
    Predict upvotes for an array of potential posting times.

    The shared post content and the two timing lists (their equal length is
    enforced by ``PredictBatchRequest``) are forwarded to
    ``models.reddit_predict.predict_batch`` together with
    ``LOCAL_ARTIFACTS_MODEL_PATH``. ``count`` in the response is the number
    of results predict_batch returned.

    Raises:
        HTTPException(422): predict_batch raised ``ValueError``; the
            detail carries its message.
        HTTPException(500): predict_batch raised ``FileNotFoundError``
            (e.g. the model artifact is missing). That is a server-side
            deployment fault, not a client error, so it is logged here and
            the server's filesystem path is not echoed to the caller.
    """
    try:
        results = predict_batch(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago_list=request.hours_ago_list,
            post_utc_list=request.post_utc_list,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        # Input the pipeline rejected: report it as a client error.
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # Missing artifact is the server's problem; keep the path in the logs.
        logging.getLogger(__name__).exception("Model artifact not found")
        raise HTTPException(
            status_code=500,
            detail="Model artifact is not available on the server.",
        ) from e

    return PredictBatchResponse(predictions=results, count=len(results))
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces so the server is reachable from outside a
    # container. The port defaults to 7860; setting the PORT environment
    # variable (commonly provided by hosting platforms) overrides it. A
    # non-integer PORT fails fast with ValueError at startup.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|