""" app.py ------ FastAPI service exposing the Reddit upvotes prediction pipeline. Endpoints --------- POST /predict_single — single prediction via hours_ago and post_utc POST /predict_batch — batch prediction via lists of hours_ago and post_utc GET /health — liveness check """ import sys import os from pathlib import Path # Fix sys.path so we can import the refactored code correctly PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent sys.path.append(str(PROJECT_ROOT)) sys.path.append(str(PROJECT_ROOT / "models")) from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field, field_validator import uvicorn from typing import List, Dict, Any from models.reddit_predict import predict_single, predict_batch app = FastAPI( title="Reddit Upvotes Predictor", description="Predicts the number of upvotes a Reddit post will receive.", version="1.0.0", ) # Reference the artifacts directly in this specific 'test' directory as requested LOCAL_ARTIFACTS_MODEL_PATH = Path(__file__).resolve().parent / "artifacts" / "model_reddit_final.pth" # --------------------------------------------------------------------------- # Request / Response schemas # --------------------------------------------------------------------------- class PostContent(BaseModel): """Shared post content fields, inherited by all request models.""" subreddit: str = Field(..., description="Subreddit name (e.g. 'MachineLearning')") title: str = Field(..., min_length=1, description="Post title") text: str = Field(..., min_length=1, description="Post body text") class PredictSingleRequest(PostContent): """Single prediction.""" hours_ago: float post_utc: float class PredictBatchRequest(PostContent): """Batch prediction.""" hours_ago_list: List[float] = Field(..., min_length=1) post_utc_list: List[float] = Field(..., min_length=1) @field_validator("post_utc_list") @classmethod def validate_lengths(cls, v: List[float], info: Any) -> List[float]: if 'hours_ago_list' in info.data and len(v) != len(info.data['hours_ago_list']): raise ValueError("hours_ago_list and post_utc_list must have the same length.") return v class PredictSingleResponse(BaseModel): prediction: float class BatchPredictionItem(BaseModel): hours_ago: float post_utc: float prediction: float class PredictBatchResponse(BaseModel): predictions: List[BatchPredictionItem] count: int # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @app.get("/health", tags=["Meta"]) def health(): """Liveness check.""" return {"status": "ok"} @app.post( "/predict_single", response_model=PredictSingleResponse, summary="Single prediction with mutual specific timestamp", tags=["Predict"], ) def api_predict_single(request: PredictSingleRequest): """ Predict upvotes for a single post context given both hours_ago and post_utc. """ try: result = predict_single( subreddit=request.subreddit, title=request.title, text=request.text, hours_ago=request.hours_ago, post_utc=request.post_utc, model_path=LOCAL_ARTIFACTS_MODEL_PATH ) except (ValueError, FileNotFoundError) as e: raise HTTPException(status_code=422, detail=str(e)) return PredictSingleResponse(prediction=result) @app.post( "/predict_batch", response_model=PredictBatchResponse, summary="Batch prediction with matching length lists", tags=["Predict"], ) def api_predict_batch(request: PredictBatchRequest): """ Predict upvotes for an array of potential posting times. """ try: results = predict_batch( subreddit=request.subreddit, title=request.title, text=request.text, hours_ago_list=request.hours_ago_list, post_utc_list=request.post_utc_list, model_path=LOCAL_ARTIFACTS_MODEL_PATH ) except (ValueError, FileNotFoundError) as e: raise HTTPException(status_code=422, detail=str(e)) return PredictBatchResponse(predictions=results, count=len(results)) # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860) # git add # git commit -m # git push # git remote -v ??? # git config --global credential.helper store # uvicorn app:app --reload # git commit conventions? # https://cyrilfrl-test.hf.space/greet # git commit conventions # curl -X POST https://cyrilfrl-test.hf.space/greet -H "Content-Type: application/json" -d '{"name": "Cyril"}' # (dev) C:\Users\Cyril\Desktop\B3Q2\startup\app\HuggingFace\RedditPredictUpvotes\test>curl -X POST https://cyrilfrl-test.hf.space/greet -H "Content-Type: application/json" -d "{\"name\": \"Cyril\"}" # {"message":"Hello Cyril!!"} # (dev) C:\Users\Cyril\Desktop\B3Q2\startup\app\HuggingFace\RedditPredictUpvotes\test> # curl -X POST "http://127.0.0.1:8000/greet" -H "Content-Type: application/json" -d "{\"name\": \"Cyril\"}"