| """ |
| app.py |
| ------ |
| FastAPI service exposing the Reddit upvotes prediction pipeline. |
| |
| Endpoints |
| --------- |
| POST /predict_single — single prediction via hours_ago and post_utc |
| POST /predict_batch — batch prediction via lists of hours_ago and post_utc |
| GET /health — liveness check |
| """ |
|
|
import logging
import os
import sys
from pathlib import Path
|
|
| |
# Fifth ancestor of this file's resolved path -- presumably the project root;
# update the number of `.parent` hops if this file is ever moved.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent

# Append the root (so `models.reddit_predict`, imported further down, can be
# resolved) and its `models` sub-directory (presumably for modules in there
# that import their siblings by bare name) to the end of the import path.
sys.path.extend(str(entry) for entry in (PROJECT_ROOT, PROJECT_ROOT / "models"))
|
|
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field, field_validator |
| import uvicorn |
| from typing import List, Dict, Any |
|
|
| from models.reddit_predict import predict_single, predict_batch |
|
|
# ASGI application. title/description/version populate the "info" section of
# the generated OpenAPI schema and the interactive docs pages.
app = FastAPI(
    version="1.0.0",
    title="Reddit Upvotes Predictor",
    description="Predicts the number of upvotes a Reddit post will receive.",
)
|
|
| |
# Model artifact handed to the prediction functions as `model_path`. It is
# anchored to this file's directory (not the CWD), under `artifacts/`.
LOCAL_ARTIFACTS_MODEL_PATH = (
    Path(__file__).resolve().parent.joinpath("artifacts", "model_reddit_final.pth")
)
|
|
| |
| |
| |
|
|
class PostContent(BaseModel):
    """Shared post content fields, inherited by all request models."""

    # All three fields are required. min_length=1 rejects empty strings, so a
    # blank subreddit is refused just like a blank title or body (previously
    # `subreddit=""` slipped through while `title`/`text` were checked).
    subreddit: str = Field(
        ..., min_length=1, description="Subreddit name (e.g. 'MachineLearning')"
    )
    title: str = Field(..., min_length=1, description="Post title")
    text: str = Field(..., min_length=1, description="Post body text")
|
|
|
|
class PredictSingleRequest(PostContent):
    """Single prediction."""

    # Both timing values are required and are forwarded unchanged to
    # predict_single by the /predict_single handler.
    # NOTE(review): post_utc looks like a Unix timestamp -- units unconfirmed.
    hours_ago: float = Field(...)
    post_utc: float = Field(...)
|
|
|
|
class PredictBatchRequest(PostContent):
    """Batch prediction."""

    # Two non-empty lists that must have the same length (checked below);
    # they appear to be paired index-by-index, one entry per candidate time.
    hours_ago_list: List[float] = Field(..., min_length=1)
    post_utc_list: List[float] = Field(..., min_length=1)

    @field_validator("post_utc_list")
    @classmethod
    def validate_lengths(cls, v: List[float], info: Any) -> List[float]:
        """Reject a ``post_utc_list`` whose length differs from ``hours_ago_list``.

        Pydantic validates fields in declaration order, so ``hours_ago_list``
        has already been processed here. If it failed its own validation it
        is missing from ``info.data`` and the comparison is skipped.
        """
        counterpart = info.data.get("hours_ago_list")
        if counterpart is None or len(counterpart) == len(v):
            return v
        raise ValueError("hours_ago_list and post_utc_list must have the same length.")
|
|
|
|
class PredictSingleResponse(BaseModel):
    # Output of predict_single for one post, kept as a float (no rounding is
    # applied by the API layer).
    prediction: float = Field(...)
|
|
|
|
class BatchPredictionItem(BaseModel):
    # One row of a batch response: a pair of timing values (presumably echoed
    # from the request) together with the prediction made for them.
    hours_ago: float = Field(...)
    post_utc: float = Field(...)
    prediction: float = Field(...)
|
|
|
|
class PredictBatchResponse(BaseModel):
    # `count` is redundant with len(predictions); the /predict_batch handler
    # fills it in exactly that way.
    predictions: List[BatchPredictionItem] = Field(...)
    count: int = Field(...)
|
|
|
|
| |
| |
| |
|
|
@app.get("/health", tags=["Meta"])
def health():
    """Liveness check."""
    # Constant payload: the model is not touched, so a 200 here only means
    # the process is up and serving HTTP.
    return dict(status="ok")
|
|
|
|
@app.post(
    "/predict_single",
    response_model=PredictSingleResponse,
    summary="Single prediction with mutual specific timestamp",
    tags=["Predict"],
)
def api_predict_single(request: PredictSingleRequest):
    """
    Predict upvotes for a single post context given both hours_ago and post_utc.

    Returns 422 if the prediction pipeline rejects the input, and 503 if a
    required file (such as the model artifact) is missing on the server.
    """
    try:
        result = predict_single(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago=request.hours_ago,
            post_utc=request.post_utc,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        # The pipeline rejected the input: a client-side problem.
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # A missing file is a server misconfiguration, not a bad request.
        # Its message typically contains a server filesystem path, so log the
        # details and send the client a generic message instead.
        logging.getLogger(__name__).exception(
            "Required file not found while running predict_single"
        )
        raise HTTPException(
            status_code=503, detail="Prediction model is unavailable."
        ) from e

    return PredictSingleResponse(prediction=result)
|
|
|
|
@app.post(
    "/predict_batch",
    response_model=PredictBatchResponse,
    summary="Batch prediction with matching length lists",
    tags=["Predict"],
)
def api_predict_batch(request: PredictBatchRequest):
    """
    Predict upvotes for an array of potential posting times.

    Returns 422 if the prediction pipeline rejects the input, and 503 if a
    required file (such as the model artifact) is missing on the server.
    """
    try:
        results = predict_batch(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago_list=request.hours_ago_list,
            post_utc_list=request.post_utc_list,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        # The pipeline rejected the input: a client-side problem.
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # A missing file is a server misconfiguration, not a bad request.
        # Its message typically contains a server filesystem path, so log the
        # details and send the client a generic message instead.
        logging.getLogger(__name__).exception(
            "Required file not found while running predict_batch"
        )
        raise HTTPException(
            status_code=503, detail="Prediction model is unavailable."
        ) from e

    # `count` reflects what the pipeline returned, not the request lists.
    return PredictBatchResponse(predictions=results, count=len(results))
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Listen on all interfaces. The port defaults to the original 7860 but can
    # be overridden with the PORT environment variable, which many hosting
    # platforms set for the process.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
|
|